@@ -16,11 +16,17 @@ from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
-import hivemind
+from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
-from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
+from arguments import (
+    AlbertTrainingArguments,
+    AveragerArguments,
+    CollaborationArguments,
+    DatasetArguments,
+    ProgressTrackerArguments,
+)
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -90,8 +96,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     def __init__(
         self,
-        dht: hivemind.DHT,
-        optimizer: hivemind.CollaborativeOptimizer,
+        dht: DHT,
+        optimizer: Optimizer,
         model: torch.nn.Module,
         local_public_key: bytes,
         statistics_expiration: float,
@@ -99,7 +105,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
     ):
         super().__init__()
         self.model = model
-        self.dht, self.collaborative_optimizer = dht, optimizer
+        self.dht, self.optimizer = dht, optimizer
         self.local_public_key = local_public_key
         self.statistics_expiration = statistics_expiration
         self.last_reported_collaboration_step = -1
@@ -114,7 +120,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
     ):
         logger.info("Loading state from peers")
-        self.collaborative_optimizer.load_state_from_peers()
+        self.optimizer.load_state_from_peers()
 
     def on_step_end(
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
@@ -124,40 +130,43 @@ class CollaborativeCallback(transformers.TrainerCallback):
             self.restore_from_backup(self.latest_backup)
             return control
 
+        local_progress = self.optimizer.local_progress
+
         if state.log_history:
             self.loss += state.log_history[-1]["loss"]
             self.steps += 1
-            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
-                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
+
+            if self.optimizer.local_epoch != self.last_reported_collaboration_step:
+                self.last_reported_collaboration_step = self.optimizer.local_epoch
                 self.total_samples_processed += self.samples
-                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
+                samples_per_second = local_progress.samples_per_second
                 statistics = utils.LocalMetrics(
-                    step=self.collaborative_optimizer.local_step,
+                    step=self.optimizer.local_epoch,
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.optimizer.local_epoch}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
                     logger.info(f"Local loss: {self.loss / self.steps}")
-                if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
+                if self.optimizer.local_epoch % self.backup_every_steps == 0:
                     self.latest_backup = self.backup_state()
 
                 self.loss = 0
                 self.steps = 0
-                if self.collaborative_optimizer.is_synchronized:
+                if self.optimizer.is_synchronized_with_peers():
                     self.dht.store(
-                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        key=self.optimizer.run_id + "_metrics",
                         subkey=self.local_public_key,
                         value=statistics.dict(),
-                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                        expiration_time=get_dht_time() + self.statistics_expiration,
                         return_future=True,
                     )
 
-        self.samples = self.collaborative_optimizer.local_samples_accumulated
+        self.samples = local_progress.samples_accumulated
 
         return control
 
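# Illustrative sketch (not part of the patch): how a monitoring peer could read back the
# per-peer metrics that the callback above stores under f"{run_id}_metrics". This mirrors the
# example's companion monitor script; the exact shape of dht.get()'s return value (a dict of
# subkey -> entry with a .value payload) and the LocalMetrics schema are assumptions to verify.
from hivemind import DHT

monitor_dht = DHT(initial_peers=collaboration_args.initial_peers, client_mode=True, start=True)
metrics_entry = monitor_dht.get(collaboration_args.experiment_prefix + "_metrics", latest=True)
if metrics_entry is not None and metrics_entry.value:
    metrics = [utils.LocalMetrics.parse_obj(entry.value) for entry in metrics_entry.value.values()]
    logger.info(f"Epoch {max(m.step for m in metrics)}: {len(metrics)} reporting peers, "
                f"{sum(m.samples_per_second for m in metrics):.1f} samples/s in total")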
@@ -170,19 +179,17 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     @torch.no_grad()
     def backup_state(self) -> bytes:
-        return pickle.dumps(
-            {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
-        )
+        return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})
 
     @torch.no_grad()
     def restore_from_backup(self, backup: bytes):
         state = pickle.loads(backup)
         self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])
+        self.optimizer.load_state_dict(state["optimizer"])
 
 
 class NoOpScheduler(LRSchedulerBase):
-    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""
 
     def get_lr(self):
         return [group["lr"] for group in self.optimizer.param_groups]
@@ -202,8 +209,16 @@ class NoOpScheduler(LRSchedulerBase):
 
 
 def main():
-    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
-    training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser(
+        (
+            AlbertTrainingArguments,
+            DatasetArguments,
+            CollaborationArguments,
+            AveragerArguments,
+            ProgressTrackerArguments,
+        )
+    )
+    training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()
 
     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
     if len(collaboration_args.initial_peers) == 0:
@@ -228,7 +243,7 @@ def main():
     validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
 
-    dht = hivemind.DHT(
+    dht = DHT(
         start=True,
         initial_peers=collaboration_args.initial_peers,
         client_mode=collaboration_args.client_mode,
@@ -246,19 +261,24 @@ def main():
     adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
 
-    collaborative_optimizer = hivemind.CollaborativeOptimizer(
-        opt=opt,
+    optimizer = Optimizer(
         dht=dht,
-        scheduler=scheduler,
-        prefix=collaboration_args.experiment_prefix,
-        compression=hivemind.Float16Compression(),
-        batch_size_per_step=total_batch_size_per_step,
-        bandwidth=collaboration_args.bandwidth,
+        run_id=collaboration_args.experiment_prefix,
         target_batch_size=adjusted_target_batch_size,
+        batch_size_per_step=total_batch_size_per_step,
+        optimizer=opt,
+        scheduler=scheduler,
+        matchmaking_time=collaboration_args.matchmaking_time,
+        averaging_timeout=collaboration_args.averaging_timeout,
+        offload_optimizer=True,
+        delay_optimizer_step=True,
+        delay_grad_averaging=True,
         client_mode=collaboration_args.client_mode,
+        grad_compression=Float16Compression(),
+        state_averaging_compression=Float16Compression(),
+        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
+        tracker_opts=asdict(tracker_args),
         verbose=True,
-        start=True,
-        **asdict(averager_args),
     )
 
     class TrainerWithIndependentShuffling(Trainer):
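# Illustrative sketch (not part of the patch): the same hivemind.Optimizer outside the
# transformers.Trainer wrapper, following the library's quickstart pattern. Keyword names match
# the constructor call above; the toy model, lr, batch sizes, and use_local_updates=True (local
# steps with background parameter averaging, instead of the gradient averaging configured above)
# are placeholder assumptions.
import torch
import torch.nn.functional as F
from hivemind import DHT, Optimizer

model = torch.nn.Linear(512, 2)
dht = DHT(start=True)  # first peer; later peers would pass initial_peers=dht.get_visible_maddrs()
optimizer = Optimizer(
    dht=dht,
    run_id="albert_demo_run",   # peers sharing a run_id train together
    target_batch_size=4096,     # samples all peers must jointly process per collaborative epoch
    batch_size_per_step=32,     # samples this peer contributes with each optimizer.step()
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    use_local_updates=True,
    matchmaking_time=3.0,
    averaging_timeout=10.0,
    verbose=True,
)
for _ in range(100):
    x, y = torch.randn(32, 512), torch.randint(0, 2, (32,))
    optimizer.zero_grad()
    F.cross_entropy(model(x), y).backward()
    optimizer.step()  # counts 32 samples; averages with peers when the target batch size is reached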
@@ -274,11 +294,11 @@ def main():
         data_collator=data_collator,
         train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
-        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
+        optimizers=(optimizer, NoOpScheduler(optimizer)),
         callbacks=[
             CollaborativeCallback(
                 dht,
-                collaborative_optimizer,
+                optimizer,
                 model,
                 local_public_key,
                 collaboration_args.statistics_expiration,