|
@@ -30,6 +30,10 @@ class AllReduceRunner(ServicerBase):
|
|
|
creating a full DecentralizedAverager.
|
|
|
|
|
|
:note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
|
|
|
+ :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
|
|
|
+ :param servicer: a hivemind.p2p.ServicerBase instance whose RPC signatures are used when requesting other peers.
|
|
|
+ Typically, it is a DecentralizedAverager instance or its derivative.
|
|
|
+ If None, uses ``self`` for this purpose (since this class may be a servicer itself for testing purposes).
|
|
|
:param group_id: unique identifier of this specific all-reduce run
|
|
|
:param tensors: local tensors that should be averaged with groupmates
|
|
|
:param tensors: local tensors that should be averaged with groupmates
|
|
@@ -47,6 +51,7 @@ class AllReduceRunner(ServicerBase):
|
|
|
self,
|
|
|
*,
|
|
|
p2p: P2P,
|
|
|
+ servicer: Optional[ServicerBase],
|
|
|
group_id: GroupID,
|
|
|
tensors: Sequence[torch.Tensor],
|
|
|
ordered_group_endpoints: Sequence[Endpoint],
|
|
@@ -60,6 +65,10 @@ class AllReduceRunner(ServicerBase):
|
|
|
self.endpoint = p2p.id
|
|
|
assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
|
|
|
|
|
|
+ if servicer is None:
|
|
|
+ servicer = self
|
|
|
+ self._servicer = servicer
|
|
|
+
|
|
|
modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
|
|
|
weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
|
|
|
assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
|
|
@@ -102,9 +111,7 @@ class AllReduceRunner(ServicerBase):
|
|
|
return len(self.ordered_group_endpoints)
|
|
|
|
|
|
def _get_stub(self, peer: Endpoint) -> StubBase:
|
|
|
- from hivemind.averaging.averager import DecentralizedAverager
|
|
|
-
|
|
|
- return DecentralizedAverager.get_stub(self._p2p, peer)
|
|
|
+ return self._servicer.get_stub(self._p2p, peer)
|
|
|
|
|
|
async def run(self) -> AsyncIterator[torch.Tensor]:
|
|
|
"""Run all-reduce, return differences between averaged and original tensors as they are computed"""
|