浏览代码

apply to local parameters

justheuristic 3 年之前
父节点
当前提交
cf57c29303
共有 1 个文件被更改,包括 1 次插入4 次删除
  1. 1 4
      hivemind/optim/experimental/state_averager.py

+ 1 - 4
hivemind/optim/experimental/state_averager.py

@@ -546,10 +546,7 @@ class TrainingStateAverager(DecentralizedAverager):
                 local_param.copy_(loaded_param, non_blocking=True)
 
             if self.offload_optimizer:
-                optimized_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-                loaded_parameters = loaded_parameters_and_extras[: len(optimized_parameters)]
-                for local_param, loaded_param in zip(optimized_parameters, loaded_parameters):
-                    local_param.copy_(loaded_param, non_blocking=True)
+                self._apply_optimizer_parameters_()
 
         self.local_epoch = metadata["epoch"]
         self._update_scheduler()