|
@@ -545,8 +545,8 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
|
|
|
local_param.copy_(loaded_param, non_blocking=True)
|
|
|
|
|
|
- if self.offload_optimizer:
|
|
|
- self._apply_optimizer_parameters_()
|
|
|
+ if self.offload_optimizer:
|
|
|
+ self._apply_optimizer_parameters_()
|
|
|
|
|
|
self.local_epoch = metadata["epoch"]
|
|
|
self._update_scheduler()
|