@@ -108,6 +108,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self.samples = 0
         self.steps = 0
         self.loss = 0
+        self.total_samples_processed = 0
 
     def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                        control: transformers.TrainerControl, **kwargs):
@@ -127,7 +128,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
             self.steps += 1
             if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
-
+                self.total_samples_processed += self.samples
                 samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
                 statistics = metrics_utils.LocalMetrics(
                     step=self.collaborative_optimizer.local_step,
@@ -135,12 +136,18 @@ class CollaborativeCallback(transformers.TrainerCallback):
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps)
+                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
+                if self.steps:
+                    logger.info(f"Loss of your model: {self.loss/self.steps}")
+
                 self.loss = 0
                 self.steps = 0
                 self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
                                subkey=self.local_public_key, value=statistics.dict(),
                                expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                                return_future=True)
+
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 
         return control
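
For orientation, a minimal standalone sketch of the bookkeeping this patch adds, with hypothetical names and no hivemind or transformers dependencies. It is a simplification: the real callback refreshes self.samples from collaborative_optimizer.local_samples_accumulated on every step and pushes the metrics to the DHT, which is omitted here.

# Hypothetical, standalone illustration of the contribution/loss bookkeeping above.
class ContributionTracker:
    def __init__(self):
        self.samples = 0                  # samples accumulated since the last collaborative step
        self.steps = 0                    # mini-steps since the last collaborative step
        self.loss = 0.0                   # summed loss over those mini-steps
        self.total_samples_processed = 0  # lifetime contribution of this peer

    def on_mini_step(self, loss_value: float, batch_size: int) -> None:
        self.loss += loss_value
        self.steps += 1
        self.samples += batch_size

    def on_collaborative_step(self, step: int) -> None:
        # Fold the local window into the lifetime counter before resetting it.
        self.total_samples_processed += self.samples
        print(f"Step {step}")
        print(f"Your current contribution: {self.total_samples_processed} samples")
        if self.steps:
            # Average loss over the mini-steps in this window, as logged by the patch.
            print(f"Loss of your model: {self.loss / self.steps}")
        self.loss, self.steps, self.samples = 0.0, 0, 0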