
init averaged parameters

justheuristic 3 years ago
commit 3fd11f19e9
1 changed file with 7 additions and 3 deletions

+ 7 - 3
hivemind/optim/experimental/state_averager.py

@@ -110,7 +110,7 @@ class TrainingStateAverager(DecentralizedAverager):
         self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule
 
         self.main_parameters, self.parameter_names = main_parameters, parameter_names
-        self._averaged_parameters = tuple(self._make_host_tensor(p, force_copy=True) for p in self.main_parameters)
+        self._averaged_parameters = self._make_averaged_parameters(main_parameters)
         self.optimizer, self.scheduler = self._init_components(
             param_groups, optimizer, scheduler, initialize_optimizer
         )
@@ -153,9 +153,13 @@ class TrainingStateAverager(DecentralizedAverager):
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
         return param_groups, parameters, parameter_names
 
-    def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
+    def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
+        """Initialize averaged parameters based on the optimizer and averaging mode"""
+        return tuple(self._make_host_tensor(p.cpu() if self.offload_optimizer else p) for p in self.main_parameters)
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""
-        if self.reuse_tensors and not force_copy:
+        if self.reuse_tensors:
             if source_tensor.device != torch.device("cpu"):
                 raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
             if not source_tensor.is_shared():
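
For context, below is a minimal, self-contained sketch of the behavior this commit introduces: averaged parameters are moved to CPU when the optimizer is offloaded, and `_make_host_tensor` either reuses a shared CPU tensor (when `reuse_tensors` is set) or creates an independent host copy. The non-reuse branch is truncated in the diff above, so its body here (detach, copy to CPU, move to shared memory) is an assumption rather than the actual hivemind implementation, and the class name `_HostTensorSketch` is hypothetical.

```python
# Standalone sketch of the logic in this commit (not the actual hivemind code).
from typing import Sequence

import torch


class _HostTensorSketch:
    def __init__(self, offload_optimizer: bool = False, reuse_tensors: bool = False):
        self.offload_optimizer = offload_optimizer
        self.reuse_tensors = reuse_tensors

    def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
        """Initialize averaged parameters based on the optimizer and averaging mode."""
        # With an offloaded optimizer, averaged parameters live on the host (CPU).
        return tuple(
            self._make_host_tensor(p.cpu() if self.offload_optimizer else p)
            for p in main_parameters
        )

    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
        """Create a new CPU tensor for averaging or reuse the existing one."""
        if self.reuse_tensors:
            if source_tensor.device != torch.device("cpu"):
                raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
            if not source_tensor.is_shared():
                # assumption: move the reused tensor to shared memory so other
                # processes can read averaged values without extra copies
                source_tensor.share_memory_()
            return source_tensor
        # assumption: the non-reuse branch makes an independent CPU copy in shared memory
        averaged_tensor = source_tensor.detach().to(device="cpu", copy=True)
        return averaged_tensor.share_memory_()


if __name__ == "__main__":
    # usage example: parameters end up on CPU when the optimizer is offloaded
    params = [torch.randn(4, requires_grad=True) for _ in range(2)]
    sketch = _HostTensorSketch(offload_optimizer=True, reuse_tensors=False)
    averaged = sketch._make_averaged_parameters(params)
    assert all(t.device == torch.device("cpu") for t in averaged)
```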