justheuristic пре 3 година
родитељ
комит
da21ae3034
1 измењен фајл са 6 додатих и 6 уклоњених редова
  1. +6 -6
      hivemind/optim/experimental/state_averager.py

+ 6 - 6
hivemind/optim/experimental/state_averager.py

@@ -533,13 +533,13 @@ class TrainingStateAverager(DecentralizedAverager):
             logger.error("Failed to load state from peer, received parameters, extras or metadata.")
             return
 
-        try:
-            load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
-        except StopIteration:
-            logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
-            return
+        with torch.no_grad(), self.lock_averaged_tensors:
+            try:
+                load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+            except StopIteration:
+                logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+                return
 
-        with torch.no_grad():
             for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
                 local_param.copy_(loaded_param, non_blocking=True)