Procházet zdrojové kódy

sync peer in constructor if needed

xtinkt před 4 roky
rodič
revize
be77054fad
1 změnil soubory, kde provedl 8 přidání a 4 odebrání
  1. 8 4
      hivemind/optim/averaged.py

+ 8 - 4
hivemind/optim/averaged.py

@@ -76,6 +76,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.lock_scheduler_params = Lock()
         self.training_state = TrainingState(max_epoch=0, total_steps=0)
         self._fetch_training_state()
+        self._sync_if_needed()
 
         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
@@ -89,10 +90,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.background_fetch_training_state.start()
 
     def step(self, *args, **kwargs):
-        if not self.is_synchronized:
-            logger.warning("Peer is out of sync.")
-            self.load_states_from_peers(**kwargs)
-            return
+        self._sync_if_needed()
 
         with self.lock_scheduler_params:
             if self.local_epoch < self.training_state.max_epoch:
@@ -118,6 +116,12 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
 
         return step_result
 
+    def _sync_if_needed(self, **kwargs):
+        """Load states from peers if out of sync; kwargs are passed to load_states_from_peers."""
+        if not self.is_synchronized:
+            logger.warning("Peer is out of sync.")
+            self.load_states_from_peers(**kwargs)
+
     def zero_grad(self, *args, **kwargs):
         return self.opt.zero_grad(*args, **kwargs)