
Hotfix: offload_optimizer in load_state_from_peers (#417)

hivemind.optim.experimental.Optimizer with offload_optimizer=True behaved incorrectly when loading state from peers.

It would load the downloaded state into the local parameters and was then meant to write those new parameters into the offloaded optimizer, but instead it overwrote the newly loaded parameters with the old offloaded ones. This PR fixes that.
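
To make the ordering issue concrete, here is a minimal sketch of the bug in plain PyTorch. The names main_param, offloaded_param, and loaded_value are illustrative stand-ins, not hivemind's actual tensors:

import torch

main_param = torch.zeros(3)        # parameters used by the model
offloaded_param = torch.zeros(3)   # the optimizer's offloaded copy
loaded_value = torch.ones(3)       # fresh state downloaded from peers

# Before the fix: the downloaded state went into the main parameters first...
main_param.copy_(loaded_value)
# ...and _apply_optimizer_parameters_() then copied the stale offloaded
# tensors back over them, discarding the freshly loaded state:
main_param.copy_(offloaded_param)
assert torch.all(main_param == 0)  # the loaded state is lost

# After the fix: the state is loaded into the offloaded parameters instead,
# so the subsequent offloaded -> main copy propagates it correctly:
offloaded_param.copy_(loaded_value)
main_param.copy_(offloaded_param)
assert torch.all(main_param == 1)  # the loaded state survives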
justheuristic, 3 years ago
commit 318bb7ad48
1 changed file with 4 additions and 1 deletion

+ 4 - 1
hivemind/optim/experimental/state_averager.py

@@ -631,7 +631,8 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether the averager succeeded in loading parameters
         """
-        main_parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
+        opt_parameters = tuple(param for param_group in self.optimizer.param_groups for param in param_group["params"])
+        main_parameters_and_extras = tuple(chain(opt_parameters, self.extra_tensors))
         num_parameters_and_extras = len(main_parameters_and_extras)
 
         loaded_state = super().load_state_from_peers(**kwargs)
@@ -661,6 +662,8 @@ class TrainingStateAverager(DecentralizedAverager):
 
         if self.offload_optimizer:
             self._apply_optimizer_parameters_()
+        if not self.reuse_tensors:
+            self._load_local_tensors_into_averager_()
 
         self.local_epoch = metadata["epoch"]
         self._update_scheduler()
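
The second hunk covers the case where the averager keeps its own copies of the trainer's tensors (reuse_tensors=False): after fresh state has been loaded locally, it must also be mirrored into the averager's buffers, or subsequent averaging rounds would still operate on stale copies. A minimal sketch of that mirroring step, again in plain PyTorch with illustrative names (local_tensor, averager_copy):

import torch

local_tensor = torch.ones(3)    # trainer-side tensor, freshly loaded from peers
averager_copy = torch.zeros(3)  # the averager's private buffer (reuse_tensors=False)

# Without the added _load_local_tensors_into_averager_() call, the averager
# would keep averaging its stale buffer; conceptually the fix boils down to:
averager_copy.copy_(local_tensor)
assert torch.equal(averager_copy, local_tensor)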