|
@@ -546,10 +546,7 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
local_param.copy_(loaded_param, non_blocking=True)
|
|
|
|
|
|
if self.offload_optimizer:
|
|
|
- optimized_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
|
|
|
- loaded_parameters = loaded_parameters_and_extras[: len(optimized_parameters)]
|
|
|
- for local_param, loaded_param in zip(optimized_parameters, loaded_parameters):
|
|
|
- local_param.copy_(loaded_param, non_blocking=True)
|
|
|
+ self._apply_optimizer_parameters_()
|
|
|
|
|
|
self.local_epoch = metadata["epoch"]
|
|
|
self._update_scheduler()
|