
load params into offloaded optimizer

justheuristic 3 years ago
parent
commit
ac8fb55af7
1 changed file with 10 additions and 3 deletions

+ 10 - 3
hivemind/optim/experimental/state_averager.py

@@ -512,8 +512,8 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or not the averager succeeded in loading parameters
         """
-        parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
-        num_parameters_and_extras = len(parameters_and_extras)
+        main_parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(main_parameters_and_extras)
 
         loaded_state = super().load_state_from_peers(**kwargs)
         if loaded_state is None:
@@ -537,8 +537,15 @@ class TrainingStateAverager(DecentralizedAverager):
             return
 
         with torch.no_grad():
-            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
+            for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
                 local_param.copy_(loaded_param, non_blocking=True)
+
+            if self.offload_optimizer:
+                optimized_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+                loaded_parameters = loaded_parameters_and_extras[:len(optimized_parameters)]
+                for local_param, loaded_param in zip(optimized_parameters, loaded_parameters):
+                    local_param.copy_(loaded_param, non_blocking=True)
+
         self.local_epoch = metadata["epoch"]
         self._update_scheduler()
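
The patch copies freshly downloaded parameters not only into the training-side tensors but also into the optimizer's own parameter tensors when optimizer offloading is enabled. Below is a minimal sketch of why that second copy is needed; the setup (`main_parameters`, `offloaded_copies`, the standalone `offload_optimizer` flag) is an illustrative assumption, not hivemind's actual `TrainingStateAverager` internals.

```python
import torch

# Hypothetical setup: "main" parameters live where training runs (e.g. GPU),
# while an offloaded optimizer is built over separate CPU copies of them.
main_parameters = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]
offloaded_copies = [torch.nn.Parameter(p.detach().clone().cpu()) for p in main_parameters]
optimizer = torch.optim.SGD(offloaded_copies, lr=0.1)
offload_optimizer = True

# Pretend these tensors were just downloaded from a peer.
loaded_parameters_and_extras = [torch.ones_like(p) for p in main_parameters]

with torch.no_grad():
    # 1) Update the main (training-side) parameters, as before the patch.
    for local_param, loaded_param in zip(main_parameters, loaded_parameters_and_extras):
        local_param.copy_(loaded_param, non_blocking=True)

    # 2) With offloading, the optimizer owns its own parameter tensors, so they
    #    must be refreshed too; otherwise the next step() would start from stale
    #    values and overwrite the state that was just loaded from peers.
    if offload_optimizer:
        optimized_parameters = [param for group in optimizer.param_groups for param in group["params"]]
        loaded_parameters = loaded_parameters_and_extras[: len(optimized_parameters)]
        for local_param, loaded_param in zip(optimized_parameters, loaded_parameters):
            local_param.copy_(loaded_param, non_blocking=True)

# Both the training-side and the offloaded copies now hold the downloaded values.
assert all(torch.equal(p.cpu(), torch.ones_like(p.cpu())) for p in optimized_parameters)
```

The slice mirrors the diff: `loaded_parameters_and_extras` may contain extra tensors beyond the optimized parameters, so only the leading entries are copied into the optimizer's parameter groups.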