
extract all-reduce inplace func

justheuristic · 3 years ago
commit abd74b71b9
1 changed file with 31 additions and 20 deletions

+ 31 - 20
hivemind/optim/power_sgd_averager.py

@@ -159,27 +159,10 @@ class PowerSGDGradientAverager(GradientAverager):
                     # we use reshape for all matrices because PowerSGD works only with 2d tensors
                     torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
 
-                allreduce_phase_p = AllReduceRunner(
-                    p2p=self._p2p,
-                    servicer_type=type(self),
-                    prefix=self.prefix,
-                    group_id=group_info.group_id + AllReducePhases.PHASE_P.name.encode(),
-                    tensors=ps,
-                    ordered_peer_ids=group_info.peer_ids,
-                    peer_fractions=peer_fractions,
-                    gathered=user_gathered,
-                    modes=modes,
-                    **kwargs,
-                )
-                self._running_groups[group_info.group_id + AllReducePhases.PHASE_P.name.encode()].set_result(allreduce_phase_p)
+                p_group_id = group_info.group_id + AllReducePhases.PHASE_P.name.encode()
+                q_group_id = group_info.group_id + AllReducePhases.PHASE_Q.name.encode()
 
-                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    async for tensor, update in azip(as_aiter(*ps), allreduce_phase_p):
-                        # all-reduce is performed asynchronously while iterating
-                        tensor.add_(update, alpha=self._averaging_alpha)
-                else:
-                    async for _ in allreduce_phase_p:  # trigger all-reduce by iterating
-                        raise ValueError("aux peers should not receive averaged tensors")
+                await self._run_allreduce_inplace_(ps, group_info, p_group_id, peer_fractions=peer_fractions, **kwargs)
 
                 for p in ps:
                     orthogonalize_(p)
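
The per-phase group ids introduced above are built by appending the phase name to the matchmaking group id, so the P and Q all-reduce rounds run under distinct ids within the same group. A minimal, self-contained sketch of that derivation (the enum below is a stand-in for the `AllReducePhases` enum defined in this file, and the byte string is a made-up placeholder for `group_info.group_id`):

```python
from enum import Enum

# Stand-in for the AllReducePhases enum used by PowerSGDGradientAverager
class AllReducePhases(Enum):
    PHASE_P = 1
    PHASE_Q = 2

group_id = b"example-matchmaking-group"  # placeholder for group_info.group_id
p_group_id = group_id + AllReducePhases.PHASE_P.name.encode()
q_group_id = group_id + AllReducePhases.PHASE_Q.name.encode()

print(p_group_id)  # b'example-matchmaking-groupPHASE_P'
print(q_group_id)  # b'example-matchmaking-groupPHASE_Q'
```
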
@@ -224,6 +207,34 @@ class PowerSGDGradientAverager(GradientAverager):
             logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
+    async def _run_allreduce_inplace_(
+        self, tensors: Sequence[torch.Tensor], group_info: GroupInfo, group_id: Optional[bytes] = None, **kwargs
+    ):
+        """Run one all-reduce round within the given group and apply the averaged updates to ``tensors`` in-place"""
+        group_id = group_info.group_id if group_id is None else group_id
+
+        runner = AllReduceRunner(
+            p2p=self._p2p,
+            servicer_type=type(self),
+            prefix=self.prefix,
+            tensors=tensors,
+            group_id=group_id,
+            ordered_peer_ids=group_info.peer_ids,
+            **kwargs,
+        )
+        assert group_id in self._running_groups, f"Group id {group_id} was not registered in _register_allreduce_group"
+        self._running_groups[group_id].set_result(runner)
+
+        if runner.modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+            async for tensor, update in azip(as_aiter(*tensors), runner):
+                tensor.add_(update, alpha=self._averaging_alpha)
+                self.last_updated = get_dht_time()
+                self._state_updated.set()
+        else:
+            async for _ in runner:
+                raise ValueError("aux peers should not receive averaged tensors")
+
     def get_current_state(self):
         """
         Get current gradient averager state and when requested by a newbie peer.
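
Inside the extracted `_run_allreduce_inplace_`, each `update` yielded by the runner is folded into the corresponding local tensor via `Tensor.add_` with an `alpha` weight. A tiny stand-alone illustration of that in-place update (the `delta` values and `alpha` below are made-up placeholders for the runner's output and `self._averaging_alpha`):

```python
import torch

local = torch.tensor([1.0, 2.0, 3.0])   # a local tensor being averaged
delta = torch.tensor([0.5, -1.0, 0.0])  # placeholder for one `update` yielded by the runner
alpha = 1.0                             # placeholder for self._averaging_alpha

local.add_(delta, alpha=alpha)  # in-place: local += alpha * delta
print(local)                    # tensor([1.5000, 1.0000, 3.0000])
```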