Sfoglia il codice sorgente

Fix test_allreduce.py

Aleksandr Borzunov 4 anni fa
parent
commit
83c5d30f4c

+ 4 - 1
hivemind/averaging/allreduce.py

@@ -24,7 +24,10 @@ class AveragingMode(Enum):
 
 class AllReduceRunner(ServicerBase):
     """
-    An internal class that runs butterfly AllReduce in a predefined group of averagers
+    An internal class that runs butterfly AllReduce in a predefined group of averagers.
+
+    This class inherits from hivemind.p2p.ServicerBase, so it can be used as an RPCServicer for testing purposes
+    without creating a full DecentralizedAverager.
 
     :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
     :param group_id: unique identifier of this specific all-reduce run

+ 1 - 1
hivemind/averaging/partition.py

@@ -32,7 +32,7 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union['CompressionType', Sequence['CompressionType']] = CompressionType.NONE,
+        compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE,
         part_size_bytes: int = 2 ** 20,
         prefetch: int = 1,
     ):

+ 6 - 9
tests/test_allreduce.py

@@ -178,14 +178,11 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
 
     class AllreduceRunnerForTesting(AllReduceRunner):
         def _get_stub(self, peer: str) -> StubBase:
-            return AllReduceRunner.get_stub(self._p2p, peer)
+            return AllreduceRunnerForTesting.get_stub(self._p2p, peer)
 
-    p2ps = []
-    initial_peers = []
-    for _ in range(4):
-        instance = await P2P.create(initial_peers=initial_peers)
-        p2ps.append(instance)
-        initial_peers.extend(await instance.get_visible_maddrs())
+    p2ps = [await P2P.create()]
+    visible_maddrs = await p2ps[0].get_visible_maddrs()
+    p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])
 
     peers = [instance.id for instance in p2ps]
     tensors_by_peer = {
@@ -196,11 +193,11 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
     allreduce_protocols = []
-    for p2p, peer in zip(p2ps, peers):
+    for p2p in p2ps:
         allreduce_protocol = AllreduceRunnerForTesting(
             p2p=p2p,
             group_id=group_id,
-            tensors=[x.clone() for x in tensors_by_peer[peer]],
+            tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
             ordered_group_endpoints=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,