|
@@ -110,7 +110,7 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
self._old_tensors: Optional[Sequence[torch.Tensor]] = None # for delta rule
|
|
|
|
|
|
self.main_parameters, self.parameter_names = main_parameters, parameter_names
|
|
|
- self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
|
|
|
+ self._averaged_parameters = tuple(self._make_host_tensor(p, force_copy=True) for p in self.main_parameters)
|
|
|
self.optimizer, self.scheduler = self._init_components(
|
|
|
param_groups, optimizer, scheduler, initialize_optimizer
|
|
|
)
|
|
@@ -153,9 +153,9 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
|
|
|
return param_groups, parameters, parameter_names
|
|
|
|
|
|
- def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
|
|
|
+ def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
|
|
|
"""Create a new tensor for averaging or reuse the existing one"""
|
|
|
- if self.reuse_tensors:
|
|
|
+ if self.reuse_tensors and not force_copy:
|
|
|
if source_tensor.device != torch.device("cpu"):
|
|
|
raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
|
|
|
if not source_tensor.is_shared():
|