Explorar el Código

init averaged parameters

justheuristic hace 3 años
padre
commit
3aa56b328a
Se ha modificado 1 fichero con 3 adiciones y 3 borrados
  1. 3 3
      hivemind/optim/experimental/state_averager.py

+ 3 - 3
hivemind/optim/experimental/state_averager.py

@@ -155,11 +155,11 @@ class TrainingStateAverager(DecentralizedAverager):
 
     def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
         """Initialize averaged parameters based on the optimizer and averaging mode"""
-        return tuple(self._make_host_tensor(p.cpu() if self.offload_optimizer else p) for p in self.main_parameters)
+        return tuple(self._make_host_tensor(param, force_copy=True) for param in main_parameters)
 
-    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
+    def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""
-        if self.reuse_tensors:
+        if self.reuse_tensors and not force_copy:
             if source_tensor.device != torch.device("cpu"):
                 raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
             if not source_tensor.is_shared():