@@ -129,7 +129,8 @@ class PowerEFGradientAverager(GradientAverager):
         cs = [rest for idx, rest in enumerate(self._gradient_rests) if idx not in self._uncompressed_gradients]
         ps = [
             torch.zeros((grad.size(0), self.rank), device="cpu")
-            for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients
+            for idx, grad in enumerate(averaged_grads)
+            if idx not in self._uncompressed_gradients
         ]
         for p, q, rest in zip(ps, self._qs, cs):
             torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
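
A minimal sketch of the P-step this hunk reformats: each compressed gradient's residual ("rest") is viewed as a 2-D matrix and multiplied by its low-rank factor Q into a preallocated P buffer. The shapes `m`, `n`, and `rank` below are illustrative placeholders, not values taken from the class.

```python
import torch

# Illustrative shapes (assumptions, not from the diff): the residual is treated
# as an (m, n) matrix, Q has shape (n, rank), and P has shape (m, rank).
m, n, rank = 8, 6, 2
rest = torch.randn(m * n)   # flattened residual, analogous to an entry of self._gradient_rests
q = torch.randn(n, rank)    # per-tensor Q factor, analogous to an entry of self._qs
p = torch.zeros(m, rank)    # preallocated output, analogous to an entry of ps

# Same call as in the hunk: reshape the residual to (m, n) and project onto Q.
torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
assert p.shape == (m, rank)
```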