@@ -112,7 +112,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                        control: transformers.TrainerControl, **kwargs):
-        logger.warning('Loading state from peers')
+        logger.info('Loading state from peers')
         self.collaborative_optimizer.load_state_from_peers()
 
     def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
@@ -139,14 +139,15 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 logger.info(f"Step {self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 if self.steps:
-                    logger.info(f"Loss of your model: {self.loss/self.steps}")
+                    logger.info(f"Local loss: {self.loss / self.steps}")
 
                 self.loss = 0
                 self.steps = 0
-                self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
-                               subkey=self.local_public_key, value=statistics.dict(),
-                               expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
-                               return_future=True)
+                if self.collaborative_optimizer.is_synchronized:
+                    self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
+                                   subkey=self.local_public_key, value=statistics.dict(),
+                                   expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                                   return_future=True)
 
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 
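The new is_synchronized guard means a peer only publishes metrics once its optimizer state has caught up with the swarm, rather than reporting stale numbers while it is still downloading state from peers. The guarded call itself uses hivemind's regular DHT API. Below is a minimal, self-contained sketch of that store-and-read pattern, assuming a standalone local DHT node; the key, subkey, metric names, and 600-second expiration are illustrative, not taken from the training code:

import hivemind

# Standalone DHT node for demonstration; in real training, peers join one swarm.
dht = hivemind.DHT(start=True)

# Illustrative per-peer metrics, standing in for statistics.dict() in the callback.
metrics = {"step": 1, "loss": 2.5, "samples_accumulated": 512}

# Store under a shared key with a per-peer subkey; the entry expires after
# 600 seconds, the same way statistics_expiration bounds entries in the diff.
dht.store(key="demo_run_metrics", subkey="peer_0", value=metrics,
          expiration_time=hivemind.get_dht_time() + 600)

# Any peer in the swarm can read back all peers' subkeys from the shared key.
print(dht.get("demo_run_metrics"))

dht.shutdown()

Because each peer writes to its own subkey, concurrent writers never clobber each other's metrics; a monitor can aggregate the whole dictionary with a single get on the shared key.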