Fix RPC in ServicerBase derivatives for test_training_averager

Aleksandr Borzunov, 4 years ago
Commit 20f19b122f
5 changed files with 24 additions and 14 deletions
  1. hivemind/averaging/allreduce.py (+10 -3)
  2. hivemind/averaging/averager.py (+7 -1)
  3. hivemind/averaging/matchmaking.py (+4 -4)
  4. hivemind/dht/__init__.py (+1 -1)
  5. tests/test_allreduce.py (+2 -5)

hivemind/averaging/allreduce.py (+10 -3)

@@ -30,6 +30,10 @@ class AllReduceRunner(ServicerBase):
     creating a full DecentralizedAverager.

     :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
+    :param servicer: a hivemind.p2p.ServicerBase instance whose RPC signatures are used when requesting other peers.
+      Typically, it is a DecentralizedAverager instance or its derivative.
+      If None, uses ``self`` for this purpose (since this class may be a servicer itself for testing purposes).
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
@@ -47,6 +51,7 @@ class AllReduceRunner(ServicerBase):
         self,
         *,
         p2p: P2P,
+        servicer: Optional[ServicerBase],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
         ordered_group_endpoints: Sequence[Endpoint],
@@ -60,6 +65,10 @@ class AllReduceRunner(ServicerBase):
         self.endpoint = p2p.id
         assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"

+        if servicer is None:
+            servicer = self
+        self._servicer = servicer
+
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
         assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
@@ -102,9 +111,7 @@ class AllReduceRunner(ServicerBase):
         return len(self.ordered_group_endpoints)

     def _get_stub(self, peer: Endpoint) -> StubBase:
-        from hivemind.averaging.averager import DecentralizedAverager
-
-        return DecentralizedAverager.get_stub(self._p2p, peer)
+        return self._servicer.get_stub(self._p2p, peer)

     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""

hivemind/averaging/averager.py (+7 -1)

@@ -214,7 +214,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     logger.debug(f"The averager is running in client mode.")

                 self._matchmaking = Matchmaking(
-                    self._p2p, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
+                    self._p2p,
+                    self,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
                 )
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
@@ -378,6 +383,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             async with self.get_tensors_async() as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
+                    servicer=self,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
                     ordered_group_endpoints=group_info.endpoints,

hivemind/averaging/matchmaking.py (+4 -4)

@@ -13,7 +13,7 @@ import asyncio
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
 from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
 from hivemind.utils.asyncio import anext
 from hivemind.proto import averaging_pb2
@@ -37,6 +37,7 @@ class Matchmaking:
     def __init__(
         self,
         p2p: P2P,
+        servicer: ServicerBase,
         schema_hash: bytes,
         dht: DHT,
         *,
@@ -57,6 +58,7 @@

         super().__init__()
         self._p2p = p2p
+        self._servicer = servicer
         self.endpoint = p2p.id
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
@@ -173,9 +175,7 @@ class Matchmaking:
         stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
         try:
             async with self.lock_request_join_group:
-                from hivemind.averaging.averager import DecentralizedAverager
-
-                leader_stub = DecentralizedAverager.get_stub(self._p2p, leader)
+                leader_stub = self._servicer.get_stub(self._p2p, leader)

                 stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(

hivemind/dht/__init__.py (+1 -1)

@@ -24,9 +24,9 @@ from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, Type
 from multiaddr import Multiaddr

 from hivemind.dht.node import DHTNode
-from hivemind.p2p import P2P, PeerID
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop

 logger = get_logger(__name__)

tests/test_allreduce.py (+2 -5)

@@ -176,10 +176,6 @@ NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
     """Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""

-    class AllreduceRunnerForTesting(AllReduceRunner):
-        def _get_stub(self, peer: str) -> StubBase:
-            return AllreduceRunnerForTesting.get_stub(self._p2p, peer)
-
     p2ps = [await P2P.create()]
     visible_maddrs = await p2ps[0].get_visible_maddrs()
     p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])
@@ -194,8 +190,9 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,

     allreduce_protocols = []
     for p2p in p2ps:
-        allreduce_protocol = AllreduceRunnerForTesting(
+        allreduce_protocol = AllReduceRunner(
             p2p=p2p,
+            servicer=AllReduceRunner,
             group_id=group_id,
             tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
             ordered_group_endpoints=peers,
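
A note on the test change above: get_stub is usable on the class itself (as the deleted AllreduceRunnerForTesting helper already did), so the test can simply pass servicer=AllReduceRunner instead of subclassing just to override _get_stub. Below is a toy illustration of passing a class, rather than an instance, as the servicer; the names are made up and not hivemind's API.

# Toy illustration: a class object works as the servicer as long as
# its get_stub is callable without an instance (here, a classmethod).
class ToyRunner:
    @classmethod
    def get_stub(cls, p2p, peer):
        return f"{cls.__name__} stub for {peer}"

    def __init__(self, p2p, servicer):
        self._p2p, self._servicer = p2p, servicer

    def _get_stub(self, peer):
        return self._servicer.get_stub(self._p2p, peer)


runner = ToyRunner(p2p=None, servicer=ToyRunner)  # pass the class, not an instance
print(runner._get_stub("peer-3"))                 # prints: ToyRunner stub for peer-3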