@@ -512,8 +512,8 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or not the averager succeeded in loading parameters
"""
|
|
|
- parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
|
|
|
- num_parameters_and_extras = len(parameters_and_extras)
|
|
|
+ main_parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
|
|
|
+ num_parameters_and_extras = len(main_parameters_and_extras)
|
|
|
|
|
|
loaded_state = super().load_state_from_peers(**kwargs)
|
|
|
if loaded_state is None:
|
|
@@ -537,8 +537,15 @@ class TrainingStateAverager(DecentralizedAverager):
             return

         with torch.no_grad():
-            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
+            for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
                 local_param.copy_(loaded_param, non_blocking=True)
+
+            if self.offload_optimizer:
+                optimized_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+                loaded_parameters = loaded_parameters_and_extras[:len(optimized_parameters)]
+                for local_param, loaded_param in zip(optimized_parameters, loaded_parameters):
+                    local_param.copy_(loaded_param, non_blocking=True)
+
             self.local_epoch = metadata["epoch"]
             self._update_scheduler()
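
For context, the new offload_optimizer branch exists because an offloaded optimizer steps on its own copies of the parameters rather than on self.main_parameters, so a state downloaded from peers has to be written into both sets of tensors. The following is a minimal, self-contained sketch of that pattern with toy tensors; the names model_params, offloaded_copies and loaded are illustrative stand-ins, not hivemind objects.

import torch

# Toy stand-ins (not hivemind objects): the model's parameters and the optimizer's
# offloaded copies are distinct tensors, and both must receive the downloaded state.
model_params = [torch.nn.Parameter(torch.zeros(3))]
offloaded_copies = [torch.nn.Parameter(p.detach().clone()) for p in model_params]
optimizer = torch.optim.SGD(offloaded_copies, lr=0.1)  # steps only on the copies

loaded = [torch.ones(3)]  # pretend this arrived from load_state_from_peers

with torch.no_grad():
    # Write the downloaded state into the model parameters...
    for local_param, loaded_param in zip(model_params, loaded):
        local_param.copy_(loaded_param)
    # ...and also into the optimizer's own parameter copies, mirroring the
    # offload_optimizer branch above; otherwise the optimizer would keep
    # stepping from stale values.
    optimized_parameters = [p for group in optimizer.param_groups for p in group["params"]]
    for local_param, loaded_param in zip(optimized_parameters, loaded[:len(optimized_parameters)]):
        local_param.copy_(loaded_param)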