|
@@ -631,7 +631,8 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
|
|
Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
|
|
:returns: whether or the averager succeeded in loading parameters
|
|
:returns: whether or the averager succeeded in loading parameters
|
|
"""
|
|
"""
|
|
- main_parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
|
|
|
|
|
|
+ opt_parameters = tuple(param for param_group in self.optimizer.param_groups for param in param_group["params"])
|
|
|
|
+ main_parameters_and_extras = tuple(chain(opt_parameters, self.extra_tensors))
|
|
num_parameters_and_extras = len(main_parameters_and_extras)
|
|
num_parameters_and_extras = len(main_parameters_and_extras)
|
|
|
|
|
|
loaded_state = super().load_state_from_peers(**kwargs)
|
|
loaded_state = super().load_state_from_peers(**kwargs)
|
|
@@ -661,6 +662,8 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
|
|
|
if self.offload_optimizer:
|
|
if self.offload_optimizer:
|
|
self._apply_optimizer_parameters_()
|
|
self._apply_optimizer_parameters_()
|
|
|
|
+ if not self.reuse_tensors:
|
|
|
|
+ self._load_local_tensors_into_averager_()
|
|
|
|
|
|
self.local_epoch = metadata["epoch"]
|
|
self.local_epoch = metadata["epoch"]
|
|
self._update_scheduler()
|
|
self._update_scheduler()
|