|
@@ -533,13 +533,13 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
logger.error("Failed to load state from peer, received parameters, extras or metadata.")
|
|
|
return
|
|
|
|
|
|
- try:
|
|
|
- load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
|
|
|
- except StopIteration:
|
|
|
- logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
|
|
|
- return
|
|
|
+ with torch.no_grad(), self.lock_averaged_tensors:
|
|
|
+ try:
|
|
|
+ load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
|
|
|
+ except StopIteration:
|
|
|
+ logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
|
|
|
+ return
|
|
|
|
|
|
- with torch.no_grad():
|
|
|
for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
|
|
|
local_param.copy_(loaded_param, non_blocking=True)
|
|
|
|