@@ -558,13 +558,12 @@ class TrainingStateAverager(DecentralizedAverager):
     def _apply_optimizer_parameters_(self):
         """Copy parameters from offloaded optimizer to the main model"""
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
-        with self.lock_averaged_tensors:
-            offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(offloaded_parameters) == len(
-                self.main_parameters
-            ), "Optimizer parameters changed during training"
-            for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
-                main_param.copy_(offloaded_param, non_blocking=True)
+        offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        assert len(offloaded_parameters) == len(
+            self.main_parameters
+        ), "Optimizer parameters changed during training"
+        for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
+            main_param.copy_(offloaded_param, non_blocking=True)

     @torch.no_grad()
     def _load_local_tensors_into_averager_(self):
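For context, here is a minimal, self-contained sketch (not part of the patch) of the offload pattern this method serves: the optimizer steps on detached copies of the parameters, and the copy-back mirrors the patched `_apply_optimizer_parameters_`. The patch itself only drops the `lock_averaged_tensors` guard around the copy, presumably because locking is handled by the caller. The toy model and the helper name `apply_offloaded_updates` below are illustrative assumptions, not hivemind API.

import torch

# Toy setup: "main" parameters live in the model; the optimizer only ever
# sees detached copies ("offloaded" parameters).
model = torch.nn.Linear(4, 2)
offloaded = [p.detach().clone().requires_grad_(True) for p in model.parameters()]
opt = torch.optim.SGD(offloaded, lr=0.1)

def apply_offloaded_updates(main_params, optimizer):
    # Mirrors the patched method (hypothetical helper): gather the optimizer's
    # parameters and copy them back into the main model, asserting that the
    # parameter count has not changed during training.
    offloaded_params = [p for group in optimizer.param_groups for p in group["params"]]
    assert len(offloaded_params) == len(main_params), "parameter count changed"
    with torch.no_grad():
        for main_p, off_p in zip(main_params, offloaded_params):
            main_p.copy_(off_p, non_blocking=True)

# One step: backprop on the main model, hand the gradients to the copies,
# step the offloaded optimizer, then sync the updated values back.
model(torch.randn(3, 4)).sum().backward()
for off_p, main_p in zip(offloaded, model.parameters()):
    off_p.grad = main_p.grad.detach().clone()
opt.step()
apply_offloaded_updates(list(model.parameters()), opt)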