@@ -159,27 +159,10 @@ class PowerSGDGradientAverager(GradientAverager):
                     # we use reshape for all matrixes because PowerSGD works only with 2d tensors
                     torch.matmul(m.reshape(-1, q.size(0)), q, out=p)

-                allreduce_phase_p = AllReduceRunner(
-                    p2p=self._p2p,
-                    servicer_type=type(self),
-                    prefix=self.prefix,
-                    group_id=group_info.group_id + AllReducePhases.PHASE_P.name.encode(),
-                    tensors=ps,
-                    ordered_peer_ids=group_info.peer_ids,
-                    peer_fractions=peer_fractions,
-                    gathered=user_gathered,
-                    modes=modes,
-                    **kwargs,
-                )
-                self._running_groups[group_info.group_id + AllReducePhases.PHASE_P.name.encode()].set_result(allreduce_phase_p)
+                p_group_id = group_info.group_id + AllReducePhases.PHASE_P.name.encode()
+                q_group_id = group_info.group_id + AllReducePhases.PHASE_Q.name.encode()

-                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    async for tensor, update in azip(as_aiter(*ps), allreduce_phase_p):
-                        # all-reduce is performed asynchronously while iterating
-                        tensor.add_(update, alpha=self._averaging_alpha)
-                else:
-                    async for _ in allreduce_phase_p:  # trigger all-reduce by iterating
-                        raise ValueError("aux peers should not receive averaged tensors")
+                await self._run_allreduce_inplace_(ps, group_info, p_group_id, peer_fractions=peer_fractions, **kwargs)

                 for p in ps:
                     orthogonalize_(p)
@@ -224,6 +207,34 @@ class PowerSGDGradientAverager(GradientAverager):
             logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")

+    async def _run_allreduce_inplace_(
+            self, tensors: Sequence[torch.Tensor],
+            group_info: GroupInfo,
+            group_id: Optional[bytes] = None,
+            **kwargs):
+        group_id = group_info.group_id if group_id is None else group_id
+
+        runner = AllReduceRunner(
+            p2p=self._p2p,
+            servicer_type=type(self),
+            prefix=self.prefix,
+            tensors=tensors,
+            group_id=group_id,
+            ordered_peer_ids=group_info.peer_ids,
+            **kwargs,
+        )
+        assert group_id in self._running_groups, f"Group id {group_id} was not registered in _register_allreduce_group"
+        self._running_groups[group_id].set_result(runner)
+
+        if runner.modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+            async for tensor, update in azip(as_aiter(*tensors), runner):  # all-reduce is performed asynchronously while iterating
+                tensor.add_(update, alpha=self._averaging_alpha)
+                self.last_updated = get_dht_time()
+                self._state_updated.set()
+        else:
+            async for _ in runner:  # trigger all-reduce by iterating
+                raise ValueError("aux peers should not receive averaged tensors")
+
     def get_current_state(self):
         """
         Get current gradient averager state and when requested by a newbie peer.
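
The per-phase group ids introduced in this patch are built by suffixing the matchmaking group id with the phase name, so the P and Q all-reduces of one PowerSGD round are registered as distinct groups. Below is a minimal, standalone sketch of that scheme; the enum member values and the example group id are illustrative assumptions rather than hivemind's actual definitions.

from enum import Enum


class AllReducePhases(Enum):
    # assumed to mirror hivemind's AllReducePhases enum; the numeric values are placeholders
    PHASE_P = 1
    PHASE_Q = 2


def phase_group_id(base_group_id: bytes, phase: AllReducePhases) -> bytes:
    # derive a per-phase all-reduce group id from the matchmaking group id
    return base_group_id + phase.name.encode()


if __name__ == "__main__":
    base = b"example-group-id"  # hypothetical stand-in for group_info.group_id
    print(phase_group_id(base, AllReducePhases.PHASE_P))  # b'example-group-idPHASE_P'
    print(phase_group_id(base, AllReducePhases.PHASE_Q))  # b'example-group-idPHASE_Q'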
|