Jelajahi Sumber

apply to local parameters

justheuristic 3 tahun lalu
induk
melakukan
cf57c29303
1 mengubah file dengan 1 tambahan dan 4 penghapusan
  1. 1 4
      hivemind/optim/experimental/state_averager.py

+ 1 - 4
hivemind/optim/experimental/state_averager.py

@@ -546,10 +546,7 @@ class TrainingStateAverager(DecentralizedAverager):
                 local_param.copy_(loaded_param, non_blocking=True)
 
             if self.offload_optimizer:
-                optimized_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-                loaded_parameters = loaded_parameters_and_extras[: len(optimized_parameters)]
-                for local_param, loaded_param in zip(optimized_parameters, loaded_parameters):
-                    local_param.copy_(loaded_param, non_blocking=True)
+                self._apply_optimizer_parameters_()
 
         self.local_epoch = metadata["epoch"]
         self._update_scheduler()