|
@@ -484,7 +484,7 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
Get current model/optimizer state and when requested by a newbie peer. executed in the host process.
|
|
|
:returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
|
|
|
"""
|
|
|
- with torch.no_grad():
|
|
|
+ with torch.no_grad(), self.lock_averaged_tensors:
|
|
|
optimized_parameters = tuple(
|
|
|
param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
|
|
|
)
|