|
@@ -178,14 +178,11 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
|
|
|
|
|
|
class AllreduceRunnerForTesting(AllReduceRunner):
|
|
|
def _get_stub(self, peer: str) -> StubBase:
|
|
|
- return AllReduceRunner.get_stub(self._p2p, peer)
|
|
|
+ return AllreduceRunnerForTesting.get_stub(self._p2p, peer)
|
|
|
|
|
|
- p2ps = []
|
|
|
- initial_peers = []
|
|
|
- for _ in range(4):
|
|
|
- instance = await P2P.create(initial_peers=initial_peers)
|
|
|
- p2ps.append(instance)
|
|
|
- initial_peers.extend(await instance.get_visible_maddrs())
|
|
|
+ p2ps = [await P2P.create()]
|
|
|
+ visible_maddrs = await p2ps[0].get_visible_maddrs()
|
|
|
+ p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])
|
|
|
|
|
|
peers = [instance.id for instance in p2ps]
|
|
|
tensors_by_peer = {
|
|
@@ -196,11 +193,11 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
|
|
|
group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
|
|
|
|
|
|
allreduce_protocols = []
|
|
|
- for p2p, peer in zip(p2ps, peers):
|
|
|
+ for p2p in p2ps:
|
|
|
allreduce_protocol = AllreduceRunnerForTesting(
|
|
|
p2p=p2p,
|
|
|
group_id=group_id,
|
|
|
- tensors=[x.clone() for x in tensors_by_peer[peer]],
|
|
|
+ tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
|
|
|
ordered_group_endpoints=peers,
|
|
|
peer_fractions=peer_fractions,
|
|
|
modes=peer_modes,
|