|
@@ -186,20 +186,21 @@ class CollaborativeCallback(transformers.TrainerCallback):
|
|
|
|
|
|
if state.log_history:
|
|
|
self.loss += state.log_history[-1]['loss']
|
|
|
+ self.steps += 1
|
|
|
if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
|
|
|
self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
|
|
|
|
|
|
statistics = [self.collaborative_optimizer.local_step,
|
|
|
self.collaborative_optimizer.performance_ema.samples_per_second,
|
|
|
self.samples,
|
|
|
- self.loss / self.steps if self.steps else 0]
|
|
|
+ self.loss,
|
|
|
+ self.steps]
|
|
|
self.loss = 0
|
|
|
-
|
|
|
+ self.steps = 0
|
|
|
self.dht.store(self.collaborative_optimizer.prefix + "_metrics", subkey=self.trainer_uuid,
|
|
|
value=statistics, expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
|
|
|
return_future=True)
|
|
|
self.samples = self.collaborative_optimizer.local_samples_accumulated
|
|
|
- self.steps = self.collaborative_optimizer.local_steps_accumulated
|
|
|
|
|
|
return control
|
|
|
|