
apply optimizer results faster

justheuristic 3 years ago
parent commit 4e4343d2f0
1 changed file with 5 additions and 3 deletions

+ 5 - 3
hivemind/optim/experimental/state_averager.py

@@ -331,12 +331,14 @@ class TrainingStateAverager(DecentralizedAverager):
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
+                    if self.offload_optimizer:
+                        self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
 
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
-                    self._apply_optimizer_results_()
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
 
@@ -362,7 +364,7 @@ class TrainingStateAverager(DecentralizedAverager):
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.clear()
                 if self.offload_optimizer:
-                    self._apply_optimizer_results_()
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Finished optimizer step")
 
             if averaging_round and not delay_averaging:
@@ -450,7 +452,7 @@ class TrainingStateAverager(DecentralizedAverager):
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
     @torch.no_grad()
-    def _apply_optimizer_results_(self):
+    def _apply_optimizer_parameters_(self):
         """Copy parameters from offloaded optimizer to the main model"""
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
         with self.lock_averaged_tensors:
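
For context, the rename from `_apply_optimizer_results_` to `_apply_optimizer_parameters_` reflects what the method does per its docstring: it copies parameters out of the offloaded optimizer into the main model. The method body is truncated in the diff after `with self.lock_averaged_tensors:`; below is a minimal sketch of what it plausibly contains. This is an assumption, not the commit's actual code: it presumes the offloaded parameters are reachable through `self.optimizer.param_groups` and that a `self.main_parameters` sequence exists, with the `main_param`/`opt_param` naming borrowed from the gradient-copy loop visible in the preceding hunk (`opt_param.grad.copy_(main_param.grad, non_blocking=True)`).

    @torch.no_grad()
    def _apply_optimizer_parameters_(self):
        """Copy parameters from the offloaded optimizer back into the main model."""
        assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
        with self.lock_averaged_tensors:
            # assumption: offloaded parameters live in the optimizer's param groups,
            # in the same order as the main model's parameters
            opt_params = [p for group in self.optimizer.param_groups for p in group["params"]]
            for main_param, opt_param in zip(self.main_parameters, opt_params):
                # push the optimizer's (possibly just-averaged) parameter values
                # into the model that is actually used for forward/backward passes
                main_param.copy_(opt_param, non_blocking=True)

Under that reading, the two added lines in the first hunk call this copy immediately after `_apply_averaging_results_()` when the optimizer is offloaded, so freshly averaged parameters reach the main model right after a background averaging round instead of waiting for the next optimizer step, which matches the commit message "apply optimizer results faster".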