浏览代码

sync peer in constructor id needed

xtinkt 4 年之前
父节点
当前提交
be77054fad
共有 1 个文件被更改,包括 8 次插入4 次删除
  1. 8 4
      hivemind/optim/averaged.py

+ 8 - 4
hivemind/optim/averaged.py

@@ -76,6 +76,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.lock_scheduler_params = Lock()
         self.training_state = TrainingState(max_epoch=0, total_steps=0)
         self._fetch_training_state()
+        self._sync_if_needed()
 
         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
@@ -89,10 +90,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.background_fetch_training_state.start()
 
     def step(self, *args, **kwargs):
-        if not self.is_synchronized:
-            logger.warning("Peer is out of sync.")
-            self.load_states_from_peers(**kwargs)
-            return
+        self._sync_if_needed()
 
         with self.lock_scheduler_params:
             if self.local_epoch < self.training_state.max_epoch:
@@ -118,6 +116,12 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
 
         return step_result
 
+    def _sync_if_needed(self):
+        if not self.is_synchronized:
+            logger.warning("Peer is out of sync.")
+            self.load_states_from_peers(**kwargs)
+            return
+
     def zero_grad(self, *args, **kwargs):
         return self.opt.zero_grad(*args, **kwargs)