Browse Source

lock tensors

justheuristic 3 năm trước cách đây
mục cha
commit
ac5e6cc664
1 tập tin đã thay đổi với 1 bổ sung1 xóa
  1. 1 1
      hivemind/optim/experimental/state_averager.py

+ 1 - 1
hivemind/optim/experimental/state_averager.py

@@ -484,7 +484,7 @@ class TrainingStateAverager(DecentralizedAverager):
         Get current model/optimizer state and when requested by a newbie peer. executed in the host process.
         :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
         """
-        with torch.no_grad():
+        with torch.no_grad(), self.lock_averaged_tensors:
             optimized_parameters = tuple(
                 param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
             )