@@ -65,18 +65,10 @@ class PowerSGDGradientAverager(PowerEFGradientAverager):
         self._local_ef = averager_local_ef
         self._uncompressed_gradients = set(i for i, grad in enumerate(self._grads_from_parameters()) if len(tuple(grad.size())) == 1)
         self._ms = list(
-            torch.zeros_like(grad, device=accumulate_grads_on)
+            torch.zeros_like(grad, device="cpu")
             for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
         )
-        self._gs = list(
-            torch.zeros_like(grad, device=accumulate_grads_on)
-            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
-        )
-        self._qs = list(
-            torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device=accumulate_grads_on)
-            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
-        )
-        for tensor in (self._qs + self._gs):
+        for tensor in (self._ms):
             if tensor is not None:
                 assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
                 tensor.share_memory_()
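
For context, here is a minimal standalone sketch of the buffer pattern this hunk keeps (the names below are illustrative and not taken from the codebase): 1-D gradients stay uncompressed, and every other gradient gets a zero-initialized CPU buffer that is then moved into shared memory, mirroring the surviving self._ms initialization and the share_memory_() loop above.

    # Illustrative sketch only; "grads", "uncompressed", and "ms" are hypothetical stand-ins.
    import torch

    grads = [torch.randn(64, 32), torch.randn(10)]  # stand-ins for parameter gradients
    uncompressed = {i for i, g in enumerate(grads) if g.dim() == 1}  # leave 1-D gradients uncompressed

    # Zero-initialized CPU buffers for every compressible gradient
    ms = [torch.zeros_like(g, device="cpu") for i, g in enumerate(grads) if i not in uncompressed]
    for buf in ms:
        assert buf.grad_fn is None, "buffers must be leaf tensors"
        buf.share_memory_()  # back the tensor with shared memory so other processes can access it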