
Fix RPC in ServicerBase derivatives for test_training_averager

Aleksandr Borzunov 4 years ago
parent commit 20f19b122f

+ 10 - 3
hivemind/averaging/allreduce.py

@@ -30,6 +30,10 @@ class AllReduceRunner(ServicerBase):
     creating a full DecentralizedAverager.
 
     :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
+    :param servicer: a hivemind.p2p.ServicerBase instance whose RPC signatures are used when requesting other peers.
+      Typically, it is a DecentralizedAverager instance or its derivative.
+      If None, uses ``self`` for this purpose (since this class may be a servicer itself for testing purposes).
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
@@ -47,6 +51,7 @@ class AllReduceRunner(ServicerBase):
         self,
         *,
         p2p: P2P,
+        servicer: Optional[ServicerBase],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
         ordered_group_endpoints: Sequence[Endpoint],
@@ -60,6 +65,10 @@ class AllReduceRunner(ServicerBase):
         self.endpoint = p2p.id
         assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
 
+        if servicer is None:
+            servicer = self
+        self._servicer = servicer
+
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
         assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
@@ -102,9 +111,7 @@ class AllReduceRunner(ServicerBase):
         return len(self.ordered_group_endpoints)
 
     def _get_stub(self, peer: Endpoint) -> StubBase:
-        from hivemind.averaging.averager import DecentralizedAverager
-
-        return DecentralizedAverager.get_stub(self._p2p, peer)
+        return self._servicer.get_stub(self._p2p, peer)
 
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""

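A minimal, self-contained sketch of the injection pattern introduced in the allreduce.py hunks above: the runner asks an injected servicer for stubs and falls back to itself when none is given. ServicerSketch, RunnerSketch, and DerivedServicerSketch are illustrative stand-ins, not hivemind classes, and the real get_stub builds an RPC stub rather than a string.

# Illustrative stand-ins only; not the real hivemind classes.
class ServicerSketch:
    @classmethod
    def get_stub(cls, p2p, peer):
        # Stand-in for ServicerBase.get_stub: report which class's RPC signatures would be used.
        return f"stub({cls.__name__} -> {peer})"


class RunnerSketch(ServicerSketch):
    def __init__(self, p2p, servicer=None):
        self._p2p = p2p
        # Fall back to `self`, mirroring `if servicer is None: servicer = self` in the diff.
        self._servicer = servicer if servicer is not None else self

    def _get_stub(self, peer):
        # Delegate to the injected servicer instead of a hard-coded DecentralizedAverager.
        return self._servicer.get_stub(self._p2p, peer)


class DerivedServicerSketch(ServicerSketch):
    pass


print(RunnerSketch(p2p=None)._get_stub("peer_a"))
# -> stub(RunnerSketch -> peer_a): the runner serves its own RPCs, as in the protocol test
print(RunnerSketch(p2p=None, servicer=DerivedServicerSketch)._get_stub("peer_a"))
# -> stub(DerivedServicerSketch -> peer_a): a derivative's RPC signatures are used
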
+ 7 - 1
hivemind/averaging/averager.py

@@ -214,7 +214,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     logger.debug(f"The averager is running in client mode.")
 
                 self._matchmaking = Matchmaking(
-                    self._p2p, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
+                    self._p2p,
+                    self,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
                 )
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
@@ -378,6 +383,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             async with self.get_tensors_async() as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
+                    servicer=self,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
                     ordered_group_endpoints=group_info.endpoints,

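Why the averager now threads servicer=self into Matchmaking and AllReduceRunner: a derivative such as the TrainingAverager exercised by test_training_averager keeps its own RPC identity instead of always resolving to the base DecentralizedAverager. A hedged sketch with stand-in classes (not the real hivemind API):

# Illustrative stand-ins only.
class AveragerSketch:
    @classmethod
    def get_stub(cls, p2p, peer):
        # Stand-in for ServicerBase.get_stub: the class name represents its RPC namespace.
        return f"{cls.__name__} stub -> {peer}"

    def request_peer(self, p2p, peer):
        # `self` plays the role of the `servicer=self` argument threaded through above,
        # so a subclass keeps its own identity when stubs are created.
        return self.get_stub(p2p, peer)


class TrainingAveragerSketch(AveragerSketch):
    pass


print(TrainingAveragerSketch().request_peer(p2p=None, peer="peer_b"))
# -> "TrainingAveragerSketch stub -> peer_b"; the previous hard-coded
#    DecentralizedAverager.get_stub call would always have used the base class.
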
+ 4 - 4
hivemind/averaging/matchmaking.py

@@ -13,7 +13,7 @@ import asyncio
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
 from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
 from hivemind.utils.asyncio import anext
 from hivemind.proto import averaging_pb2
@@ -37,6 +37,7 @@ class Matchmaking:
     def __init__(
         self,
         p2p: P2P,
+        servicer: ServicerBase,
         schema_hash: bytes,
         dht: DHT,
         *,
@@ -57,6 +58,7 @@ class Matchmaking:
 
         super().__init__()
         self._p2p = p2p
+        self._servicer = servicer
         self.endpoint = p2p.id
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
@@ -173,9 +175,7 @@ class Matchmaking:
         stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
         try:
             async with self.lock_request_join_group:
-                from hivemind.averaging.averager import DecentralizedAverager
-
-                leader_stub = DecentralizedAverager.get_stub(self._p2p, leader)
+                leader_stub = self._servicer.get_stub(self._p2p, leader)
 
                 stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(

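A side effect of the same injection in matchmaking.py: the deferred `from hivemind.averaging.averager import DecentralizedAverager`, previously kept inside the method to avoid a circular import, is no longer needed. A small stand-alone sketch of the resulting call shape (illustrative names only, not hivemind internals):

# Illustrative stand-ins only.
class JoinServicerSketch:
    @classmethod
    def get_stub(cls, p2p, peer):
        return f"rpc_join_group stub via {cls.__name__} -> {peer}"


class MatchmakingSketch:
    def __init__(self, p2p, servicer):
        self._p2p = p2p
        self._servicer = servicer  # the averager (or derivative) that owns the RPC handlers

    def leader_stub(self, leader):
        # Previously this method imported DecentralizedAverager locally and called its
        # get_stub; now it simply asks the injected servicer.
        return self._servicer.get_stub(self._p2p, leader)


print(MatchmakingSketch(p2p=None, servicer=JoinServicerSketch).leader_stub("leader_peer"))
# -> rpc_join_group stub via JoinServicerSketch -> leader_peer
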
+ 1 - 1
hivemind/dht/__init__.py

@@ -24,9 +24,9 @@ from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, Type
 from multiaddr import Multiaddr
 
 from hivemind.dht.node import DHTNode
-from hivemind.p2p import P2P, PeerID
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
 
 logger = get_logger(__name__)

+ 2 - 5
tests/test_allreduce.py

@@ -176,10 +176,6 @@ NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
     """Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""
 
-    class AllreduceRunnerForTesting(AllReduceRunner):
-        def _get_stub(self, peer: str) -> StubBase:
-            return AllreduceRunnerForTesting.get_stub(self._p2p, peer)
-
     p2ps = [await P2P.create()]
     visible_maddrs = await p2ps[0].get_visible_maddrs()
     p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])
@@ -194,8 +190,9 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
 
     allreduce_protocols = []
     for p2p in p2ps:
-        allreduce_protocol = AllreduceRunnerForTesting(
+        allreduce_protocol = AllReduceRunner(
             p2p=p2p,
+            servicer=AllReduceRunner,
             group_id=group_id,
             tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
             ordered_group_endpoints=peers,