@@ -125,13 +125,12 @@ class PowerEFGradientAverager(GradientAverager):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )
 
-            async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                compressed_tensors = [
-                    lt.to("cpu") for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients
-                ]
-
+            async with enter_asynchronously(self.get_tensors()) as averaged_grads:
                 cs = [rest for idx, rest in enumerate(self._gradient_rests) if idx not in self._uncompressed_gradients]
-                ps = [torch.zeros((grad.size(0), self.rank), device="cpu") for grad in compressed_tensors]
+                ps = [
+                    torch.zeros((grad.size(0), self.rank), device="cpu")
+                    for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients
+                ]
                 for p, q, rest in zip(ps, self._qs, cs):
                     torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
                 first_all_reduced = ps + [
@@ -196,8 +195,8 @@ class PowerEFGradientAverager(GradientAverager):
                     new_c = torch.matmul(p, q.t())
                     c.copy_(new_c.reshape(c.size()))
 
-            for rest, lt in zip(self._gradient_rests, local_tensors):
-                torch.add(lt, rest, out=lt)
+            for rest, grad in zip(self._gradient_rests, averaged_grads):
+                torch.add(grad, rest, out=grad)
 
             return allreduce1.gathered
         except BaseException as e:
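
For readers reviewing the hunks above: the change builds the rank-`self.rank` projections `ps` directly from what appear to be the error-feedback residuals (`self._gradient_rests`), instead of first materializing CPU copies of the local gradients, and the final loop adds each stored residual back into the corresponding averaged gradient. Below is a minimal, self-contained sketch of the projection step, not taken from this PR: the function and variable names are hypothetical, and it runs on a single process with no all-reduce.

```python
import torch

def project_residual(rest: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Mirrors `torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)` from the hunk:
    # view the flattened residual as an (m, n) matrix and project it onto the
    # rank-r factor q of shape (n, r), producing p of shape (m, r).
    m = rest.reshape(-1, q.size(0))
    return m @ q

# Toy usage: a 64x32 residual compressed to rank 4.
rank = 4
rest = torch.randn(64 * 32)
q = torch.randn(32, rank)
p = project_residual(rest, q)
assert p.shape == (64, rank)  # p is what peers would then all-reduce

# After averaging p and updating q, the low-rank reconstruction is p @ q.t(),
# matching `new_c = torch.matmul(p, q.t())` in the second hunk.
```

In this sketch, `p @ q.t()` stands in for the reconstruction step; in the PR itself the residual that remains after this reconstruction is what gets folded back into the averaged gradients by the final `torch.add(grad, rest, out=grad)` loop.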