@@ -331,12 +331,14 @@ class TrainingStateAverager(DecentralizedAverager):
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
 
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
-                    self._apply_optimizer_results_()
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
 
@@ -362,7 +364,7 @@ class TrainingStateAverager(DecentralizedAverager):
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.clear()
                 if self.offload_optimizer:
-                    self._apply_optimizer_results_()
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Finished optimizer step")
 
             if averaging_round and not delay_averaging:
@@ -450,7 +452,7 @@ class TrainingStateAverager(DecentralizedAverager):
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
     @torch.no_grad()
-    def _apply_optimizer_results_(self):
+    def _apply_optimizer_parameters_(self):
        """Copy parameters from offloaded optimizer to the main model"""
        assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
        with self.lock_averaged_tensors:
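
For context, here is a minimal, self-contained sketch (not hivemind's actual implementation) of the offload-optimizer round trip that the renamed _apply_optimizer_parameters_ helper is responsible for: the optimizer steps on a separate copy of the parameters, and the updated values are then copied back into the main model in-place. Names such as main_model, main_parameters, and offloaded_parameters are illustrative only.

import torch

# Main model whose parameters are used for forward/backward passes.
main_model = torch.nn.Linear(4, 2)
main_parameters = list(main_model.parameters())

# Offloaded copies that the optimizer actually updates (e.g. kept on CPU).
offloaded_parameters = [param.detach().clone().requires_grad_(True) for param in main_parameters]
optimizer = torch.optim.SGD(offloaded_parameters, lr=0.1)

# Compute gradients on the main model as usual.
main_model(torch.randn(8, 4)).sum().backward()

with torch.no_grad():
    # Load local gradients into the offloaded optimizer's buffers
    # (mirroring the opt_param.grad.copy_ line in the last hunk above).
    for main_param, opt_param in zip(main_parameters, offloaded_parameters):
        opt_param.grad = main_param.grad.clone()

    optimizer.step()

    # The role of _apply_optimizer_parameters_: copy the updated values back
    # into the main model so subsequent forward passes see them.
    for main_param, opt_param in zip(main_parameters, offloaded_parameters):
        main_param.copy_(opt_param, non_blocking=True)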
|