justheuristic 3 년 전
부모
커밋
d96f56771d
1개의 변경된 파일6개의 추가작업 그리고 6개의 파일을 삭제
  1. 6 6
      hivemind/optim/collaborative.py

+ 6 - 6
hivemind/optim/collaborative.py

@@ -149,7 +149,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
@@ -205,7 +205,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
 
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
@@ -241,7 +241,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
-            self.local_steps_accumulated += 1
+            self.local_updates_accumulated += 1
             self.performance_ema.update(num_processed=batch_size)
             self.should_report_progress.set()
 
@@ -254,7 +254,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.collaboration_state_updated.set()
 
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
+            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
@@ -285,7 +285,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             self.opt.step()
             self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
@@ -426,7 +426,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not isinstance(response, dict) or len(response) == 0:
             logger.log(self.status_loglevel, f"Found no active peers: {response}")
             local_eta_next_step = (
-                max(0, self.target_batch_size - self.local_steps_accumulated) / self.performance_ema.samples_per_second
+                    max(0, self.target_batch_size - self.local_updates_accumulated) / self.performance_ema.samples_per_second
             )
             return CollaborationState(
                 self.local_step,