justheuristic 3 vuotta sitten
vanhempi
commit
024b339b71
1 muutettua tiedostoa jossa 1 lisäystä ja 1 poistoa
  1. 1 1
      hivemind/optim/experimental/state_averager.py

+ 1 - 1
hivemind/optim/experimental/state_averager.py

@@ -155,7 +155,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
     def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
         """Initialize averaged parameters based on the optimizer and averaging mode"""
-        return tuple(self._make_host_tensor(param, force_copy=True) for param in main_parameters)
+        return tuple(self._make_host_tensor(param, force_copy=self.offload_optimizer) for param in main_parameters)
 
     def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""