
better naming

Artem Chumachenko · 3 years ago · parent commit 46d61e0a51
1 changed file with 7 additions and 8 deletions

+ 7 - 8
hivemind/optim/power_ef_averager.py

@@ -125,13 +125,12 @@ class PowerEFGradientAverager(GradientAverager):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )
 
-            async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                compressed_tensors = [
-                    lt.to("cpu") for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients
-                ]
-
+            async with enter_asynchronously(self.get_tensors()) as averaged_grads:
                 cs = [rest for idx, rest in enumerate(self._gradient_rests) if idx not in self._uncompressed_gradients]
-                ps = [torch.zeros((grad.size(0), self.rank), device="cpu") for grad in compressed_tensors]
+                ps = [
+                    torch.zeros((grad.size(0), self.rank), device="cpu")
+                    for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients
+                ]
                 for p, q, rest in zip(ps, self._qs, cs):
                     torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
                 first_all_reduced = ps + [
@@ -196,8 +195,8 @@ class PowerEFGradientAverager(GradientAverager):
                     new_c = torch.matmul(p, q.t())
                     c.copy_(new_c.reshape(c.size()))
 
-                for rest, lt in zip(self._gradient_rests, local_tensors):
-                    torch.add(lt, rest, out=lt)
+                for rest, grad in zip(self._gradient_rests, averaged_grads):
+                    torch.add(grad, rest, out=grad)
 
                 return allreduce1.gathered
         except BaseException as e:
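
For context, the renamed variables feed the PowerEF low-rank compression step touched by both hunks. Below is a minimal standalone sketch of that projection/reconstruction, assuming illustrative tensor shapes; `rest`, `q`, `p`, and `rank` mirror identifiers from the diff, but this is not the averager's actual code.

import torch

rank = 4
rest = torch.randn(256, 128)  # error-feedback residual for one gradient
q = torch.randn(128, rank)    # shared right low-rank factor

# Project the residual onto the rank-`rank` subspace (as in the first hunk):
p = torch.zeros(rest.size(0), rank, device="cpu")
torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)

# Reconstruct the low-rank approximation of the residual (as in the second hunk):
approx = torch.matmul(p, q.t()).reshape(rest.size())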