@@ -76,6 +76,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.lock_scheduler_params = Lock()
         self.training_state = TrainingState(max_epoch=0, total_steps=0)
         self._fetch_training_state()
+        self._sync_if_needed()
 
         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
@@ -89,10 +90,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         self.background_fetch_training_state.start()
 
     def step(self, *args, **kwargs):
-        if not self.is_synchronized:
-            logger.warning("Peer is out of sync.")
-            self.load_states_from_peers(**kwargs)
-            return
+        self._sync_if_needed()
 
         with self.lock_scheduler_params:
             if self.local_epoch < self.training_state.max_epoch:
@@ -118,6 +116,12 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
 
         return step_result
 
+    def _sync_if_needed(self):
+        if not self.is_synchronized:
+            logger.warning("Peer is out of sync.")
+            self.load_states_from_peers()
+            return
+
     def zero_grad(self, *args, **kwargs):
         return self.opt.zero_grad(*args, **kwargs)
 
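
With this change the synchronization check lives in one place: _sync_if_needed() runs once at the end of __init__ and again at the top of every step(), and load_states_from_peers() is invoked only when the peer has fallen behind (note that the helper takes no kwargs, since neither call site passes any). Below is a minimal, self-contained sketch of that control flow; the toy class and the trivial load_states_from_peers stub are illustration-only assumptions, not the real implementation.

import logging

logger = logging.getLogger(__name__)


class ToySyncingOptimizer:
    """Illustration-only stand-in for the optimizer; not the actual class."""

    def __init__(self):
        self.is_synchronized = False  # pretend this peer joined the run late
        self._sync_if_needed()        # same early check the patch adds to __init__

    def step(self):
        self._sync_if_needed()        # every step first catches up if needed
        print("running local optimizer step")

    def _sync_if_needed(self):
        # Single shared check: warn and pull state only when out of sync.
        if not self.is_synchronized:
            logger.warning("Peer is out of sync.")
            self.load_states_from_peers()

    def load_states_from_peers(self):
        # Stub for the real state download; here it just marks the peer as synced.
        self.is_synchronized = True


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    opt = ToySyncingOptimizer()  # triggers one sync during construction
    opt.step()                   # already synchronized, so it steps normally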