justheuristic před 3 roky
rodič
revize
56bdc5c5d4
1 změněný soubor se 3 přidáními a 3 odebráními
  1. +3 −3
      hivemind/optim/experimental/state_averager.py

+ 3 - 3
hivemind/optim/experimental/state_averager.py

@@ -110,7 +110,7 @@ class TrainingStateAverager(DecentralizedAverager):
         self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule
 
         self.main_parameters, self.parameter_names = main_parameters, parameter_names
-        self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
+        self._averaged_parameters = tuple(self._make_host_tensor(p, force_copy=True) for p in self.main_parameters)
         self.optimizer, self.scheduler = self._init_components(
             param_groups, optimizer, scheduler, initialize_optimizer
         )
@@ -153,9 +153,9 @@ class TrainingStateAverager(DecentralizedAverager):
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
         return param_groups, parameters, parameter_names
 
-    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
+    def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""
-        if self.reuse_tensors:
+        if self.reuse_tensors and not force_copy:
             if source_tensor.device != torch.device("cpu"):
                 raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
             if not source_tensor.is_shared():