Hotfix: offload_optimizer in load_state_from_peers (#417)

hivemind.optim.experimental.Optimizer with offload_optimizer=True behaved incorrectly when loading state from peers.

It would load the downloaded state into the local parameters and was then meant to write those new parameters into the offloaded optimizer; instead, it overwrote the newly loaded parameters with the stale offloaded ones. This PR fixes that.
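
For illustration, here is a minimal, self-contained sketch of the ordering bug, assuming an offloaded optimizer that keeps its own CPU copy of each parameter (plain torch tensors with illustrative names; not hivemind's actual code):

import torch

# Stand-ins for the real objects: a local "main" parameter and the
# optimizer's offloaded CPU copy of it.
model_param = torch.nn.Parameter(torch.zeros(3))
offloaded_param = model_param.detach().clone()

loaded_value = torch.ones(3)  # fresh state downloaded from a peer

# Buggy order: write the loaded state into the local parameter only...
model_param.data.copy_(loaded_value)
# ...then "apply optimizer parameters", copying offloaded -> local, which
# clobbers the freshly loaded values with the stale offloaded ones.
model_param.data.copy_(offloaded_param)
assert torch.all(model_param == 0)  # the loaded state is lost

# Fixed order: load into the tensors the optimizer actually owns, so the
# offloaded -> local copy propagates the new state.
offloaded_param.copy_(loaded_value)
model_param.data.copy_(offloaded_param)
assert torch.all(model_param == 1)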
justheuristic 3 years ago
commit 318bb7ad48
1 file changed with 4 additions and 1 deletion

hivemind/optim/experimental/state_averager.py +4 −1

@@ -631,7 +631,8 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or not the averager succeeded in loading parameters
         """
-        main_parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
+        opt_parameters = tuple(param for param_group in self.optimizer.param_groups for param in param_group["params"])
+        main_parameters_and_extras = tuple(chain(opt_parameters, self.extra_tensors))
         num_parameters_and_extras = len(main_parameters_and_extras)
 
         loaded_state = super().load_state_from_peers(**kwargs)
@@ -661,6 +662,8 @@ class TrainingStateAverager(DecentralizedAverager):
 
         if self.offload_optimizer:
             self._apply_optimizer_parameters_()
+        if not self.reuse_tensors:
+            self._load_local_tensors_into_averager_()
 
         self.local_epoch = metadata["epoch"]
         self._update_scheduler()
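
For context, the first hunk flattens the parameters owned by self.optimizer (across all param groups) and uses those, rather than self.main_parameters, as the target for the downloaded state; per the description above, with offload_optimizer=True these are the offloaded copies. A minimal sketch of that flattening pattern on a plain torch.optim optimizer (illustrative names, not hivemind's code):

import torch

net = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

# Flatten every parameter the optimizer steps over, across all param groups,
# mirroring the generator expression added in the patch.
opt_parameters = tuple(
    param
    for param_group in opt.param_groups
    for param in param_group["params"]
)
assert len(opt_parameters) == 2  # Linear contributes a weight and a bias

The second hunk appears to cover the complementary path: when reuse_tensors is disabled, the averager keeps separate internal buffers, so the freshly loaded tensors are written back into them via _load_local_tensors_into_averager_(), letting subsequent averaging rounds start from the loaded state.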