Michael Diskin, 4 years ago
Parent
Commit
2314e7ebd5
2 changed files with 8 additions and 5 deletions
  1. examples/albert/run_first_peer.py (+4 −2)
  2. examples/albert/run_trainer.py (+4 −3)

examples/albert/run_first_peer.py (+4 −2)

@@ -51,13 +51,15 @@ if __name__ == '__main__':
                 sum_loss = 0
                 num_samples = 0
                 sum_perf = 0
-                for step, perf, samples, loss in metrics:
+                sum_mini_steps = 0
+                for step, perf, samples, loss, mini_steps in metrics:
                     sum_loss += loss
                     alive_peers += 1
                     sum_perf += perf
                     num_samples += samples
+                    sum_mini_steps += mini_steps
                 wandb.log({
-                    "loss": sum_loss / alive_peers,
+                    "loss": sum_loss / sum_mini_steps,
                     "alive peers": alive_peers,
                     "samples": num_samples,
                     "performance": sum_perf

examples/albert/run_trainer.py (+4 −3)

@@ -186,20 +186,21 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
         if state.log_history:
             self.loss += state.log_history[-1]['loss']
+            self.steps += 1
             if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
 
                 statistics = [self.collaborative_optimizer.local_step,
                               self.collaborative_optimizer.performance_ema.samples_per_second,
                               self.samples,
-                              self.loss / self.steps if self.steps else 0]
+                              self.loss,
+                              self.steps]
                 self.loss = 0
-
+                self.steps = 0
                 self.dht.store(self.collaborative_optimizer.prefix + "_metrics", subkey=self.trainer_uuid,
                                value=statistics, expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                                return_future=True)
         self.samples = self.collaborative_optimizer.local_samples_accumulated
-        self.steps = self.collaborative_optimizer.local_steps_accumulated
 
         return control
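
On the trainer side, the callback now maintains its own mini-step counter instead of reading self.steps from the optimizer's local_steps_accumulated, and it ships the raw loss sum plus that counter to the DHT, resetting both after each report. A hedged sketch of just this report-and-reset cycle, using a hypothetical stand-in class with the DHT store replaced by a plain return value:

# Hypothetical stand-in for the callback's new bookkeeping; names and
# control flow follow the diff, everything else is stubbed out.
class LossReporter:
    def __init__(self):
        self.loss = 0.0   # loss summed since the last report
        self.steps = 0    # mini-steps counted locally since the last report

    def on_step_end(self, last_logged_loss, new_collaboration_step):
        self.loss += last_logged_loss
        self.steps += 1   # incremented here, no longer read from the optimizer
        if new_collaboration_step:
            # Report the raw sum and the step count; the monitor computes
            # sum(loss) / sum(mini_steps) across all peers (see above).
            statistics = [self.loss, self.steps]
            self.loss = 0.0
            self.steps = 0
            return statistics
        return None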