
Log more stats for user, move performance stats to examples (#257)

* refactor contribution logging

* add more logs
Michael Diskin 4 years ago
parent commit afc59d2a6b

+ 8 - 1
examples/albert/run_trainer.py

@@ -108,6 +108,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self.samples = 0
         self.steps = 0
         self.loss = 0
+        self.total_samples_processed = 0
 
     def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                        control: transformers.TrainerControl, **kwargs):
@@ -127,7 +128,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
             self.steps += 1
             if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
-
+                self.total_samples_processed += self.samples
                 samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
                 statistics = metrics_utils.LocalMetrics(
                     step=self.collaborative_optimizer.local_step,
@@ -135,12 +136,18 @@ class CollaborativeCallback(transformers.TrainerCallback):
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps)
+                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
+                if self.steps:
+                    logger.info(f"Loss of your model: {self.loss/self.steps}")
+
                 self.loss = 0
                 self.steps = 0
                 self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
                                subkey=self.local_public_key, value=statistics.dict(),
                                expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                                return_future=True)
+
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 
         return control
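The callback now keeps a lifetime counter next to the per-round ones: whenever the collaborative optimizer advances a step, the samples accumulated since the previous step are folded into `total_samples_processed` before the per-round counters reset, so the "Your current contribution" log survives the reset. A minimal sketch of that bookkeeping, using a hypothetical `ContributionTracker` stand-in for the callback (not hivemind's actual API):

```python
class ContributionTracker:
    """Hypothetical stand-in for the callback's counters (not hivemind API)."""

    def __init__(self):
        self.samples = 0                  # samples accumulated since the last optimizer step
        self.steps = 0                    # mini-steps since the last optimizer step
        self.loss = 0.0                   # summed loss since the last optimizer step
        self.total_samples_processed = 0  # lifetime contribution, added in this commit

    def on_optimizer_step(self):
        # fold the finished round into the lifetime total before the per-round reset
        self.total_samples_processed += self.samples
        if self.steps:
            print(f"Loss of your model: {self.loss / self.steps}")
        self.loss, self.steps = 0.0, 0


tracker = ContributionTracker()
tracker.samples, tracker.steps, tracker.loss = 128, 4, 9.2
tracker.on_optimizer_step()
assert tracker.total_samples_processed == 128  # survives the per-round reset
```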

+ 0 - 3
hivemind/optim/collaborative.py

@@ -127,7 +127,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
-        self.samples_processed = 0
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -192,7 +191,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
-            self.samples_processed += batch_size
             self.performance_ema.update(num_processed=self.batch_size_per_step)
             self.should_report_progress.set()
 
@@ -235,7 +233,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.update_scheduler()
 
             logger.log(self.status_loglevel, f"Optimizer step: done!")
-            logger.info(f"Your current contribution: {self.samples_processed} samples")
 
             return group_info
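With the per-optimizer `samples_processed` counter and its log line removed, `CollaborativeOptimizer` only feeds throughput into `performance_ema`, and the example callback reads `performance_ema.samples_per_second` when it reports metrics. A hedged sketch of how such an exponential-moving-average throughput tracker could work (a hypothetical re-implementation, not hivemind's actual `PerformanceEMA`):

```python
import time


class ThroughputEMA:
    """Hypothetical EMA throughput tracker; not hivemind's PerformanceEMA."""

    def __init__(self, alpha: float = 0.1):
        self.alpha = alpha             # weight given to the newest measurement
        self.samples_per_second = 0.0  # smoothed throughput estimate
        self._last_update = None

    def update(self, num_processed: int):
        now = time.perf_counter()
        if self._last_update is not None:
            elapsed = max(now - self._last_update, 1e-9)  # guard against zero division
            instant_rate = num_processed / elapsed
            # exponential moving average: blend the instantaneous rate
            # with the previous smoothed value
            self.samples_per_second = (
                self.alpha * instant_rate + (1 - self.alpha) * self.samples_per_second
            )
        self._last_update = now


ema = ThroughputEMA(alpha=0.2)
for _ in range(3):
    ema.update(num_processed=256)
    time.sleep(0.01)
print(f"{ema.samples_per_second:.1f} samples/sec")
```

Moving the user-facing contribution log into the example callback, rather than keeping it inside the library optimizer, keeps hivemind's core free of application-specific reporting, which matches the commit's stated goal of moving performance stats to examples.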