
Implement servicer namespaces

Aleksandr Borzunov, 4 years ago
parent commit 182529f668

+ 4 - 1
hivemind/averaging/allreduce.py

@@ -34,6 +34,7 @@ class AllReduceRunner(ServicerBase):
     :param servicer: a hivemind.p2p.ServicerBase instance whose RPC signatures are used when requesting other peers.
       Typically, it is a DecentralizedAverager instance or its derivative.
       If None, uses ``self`` for this purpose (since this class may be a servicer itself for testing purposes).
+    :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
@@ -52,6 +53,7 @@ class AllReduceRunner(ServicerBase):
         *,
         p2p: P2P,
         servicer: Optional[Union[ServicerBase, Type[ServicerBase]]],
+        prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
         ordered_group_endpoints: Sequence[Endpoint],
@@ -68,6 +70,7 @@ class AllReduceRunner(ServicerBase):
         if servicer is None:
             servicer = self
         self._servicer = servicer
+        self._prefix = prefix
 
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
@@ -111,7 +114,7 @@ class AllReduceRunner(ServicerBase):
         return len(self.ordered_group_endpoints)
 
     def _get_peer_stub(self, peer: Endpoint) -> StubBase:
-        return self._servicer.get_stub(self._p2p, peer)
+        return self._servicer.get_stub(self._p2p, peer, namespace=self._prefix)
 
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""

+ 4 - 2
hivemind/averaging/averager.py

@@ -122,6 +122,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         super().__init__()
         self.dht = dht
+        self.prefix = prefix
 
         if client_mode is None:
             client_mode = dht.client_mode
@@ -214,7 +215,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             async def _run():
                 self._p2p = await self.dht.replicate_p2p()
                 if not self.client_mode:
-                    await self.add_p2p_handlers(self._p2p)
+                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                 else:
                     logger.debug(f"The averager is running in client mode.")
 
@@ -389,6 +390,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
                     servicer=self,
+                    prefix=self.prefix,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
                     ordered_group_endpoints=group_info.endpoints,
@@ -560,7 +562,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if peer != self.endpoint:
                     logger.info(f"Downloading parameters from peer {peer}")
                     try:
-                        stub = self.get_stub(self._p2p, peer)
+                        stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
                         async for message in stream:
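
A hedged usage sketch of what this enables: two averagers with different prefixes can share the same DHT (and therefore the same replicated P2P instance) without their rpc_* handlers colliding, since each registers under its own namespace. The constructor arguments below follow the usual DecentralizedAverager signature of this period and may differ slightly between versions.

import torch
import hivemind
from hivemind.averaging import DecentralizedAverager

dht = hivemind.DHT(start=True)

# Each averager registers its handlers under its own namespace,
# e.g. "demo_a::DecentralizedAverager.rpc_join_group".
averager_a = DecentralizedAverager(
    averaged_tensors=[torch.zeros(3)], dht=dht,
    prefix="demo_a", target_group_size=2, start=True,
)
averager_b = DecentralizedAverager(
    averaged_tensors=[torch.zeros(3)], dht=dht,
    prefix="demo_b", target_group_size=2, start=True,
)

averager_a.shutdown()
averager_b.shutdown()
dht.shutdown()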

+ 1 - 1
hivemind/averaging/matchmaking.py

@@ -174,7 +174,7 @@ class Matchmaking:
         stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
         try:
             async with self.lock_request_join_group:
-                leader_stub = self._servicer.get_stub(self._p2p, leader)
+                leader_stub = self._servicer.get_stub(self._p2p, leader, namespace=self._servicer.prefix)
 
                 stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
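
Note that this call assumes the servicer exposes a prefix attribute; DecentralizedAverager now sets one in its constructor (see averager.py above).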

+ 13 - 9
hivemind/p2p/servicer.py

@@ -24,9 +24,10 @@ class StubBase:
     adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
     """
 
-    def __init__(self, p2p: P2P, peer: PeerID):
+    def __init__(self, p2p: P2P, peer: PeerID, namespace: Optional[str]):
         self._p2p = p2p
         self._peer = peer
+        self._namespace = namespace
 
 
 class ServicerBase:
@@ -97,7 +98,7 @@ class ServicerBase:
 
                 return self._p2p.iterate_protobuf_handler(
                     self._peer,
-                    cls._get_handle_name(handler.method_name),
+                    cls._get_handle_name(self._namespace, handler.method_name),
                     input,
                     handler.response_type,
                 )
@@ -110,7 +111,7 @@ class ServicerBase:
                 return await asyncio.wait_for(
                     self._p2p.call_protobuf_handler(
                         self._peer,
-                        cls._get_handle_name(handler.method_name),
+                        cls._get_handle_name(self._namespace, handler.method_name),
                         input,
                         handler.response_type,
                     ),
@@ -120,26 +121,29 @@ class ServicerBase:
         caller.__name__ = handler.method_name
         return caller
 
-    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
+    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None) -> None:
         self._collect_rpc_handlers()
 
         servicer = self if wrapper is None else wrapper
         for handler in self._rpc_handlers:
             await p2p.add_protobuf_handler(
-                self._get_handle_name(handler.method_name),
+                self._get_handle_name(namespace, handler.method_name),
                 getattr(servicer, handler.method_name),
                 handler.request_type,
                 stream_input=handler.stream_input,
             )
 
     @classmethod
-    def get_stub(cls, p2p: P2P, peer: PeerID) -> StubBase:
+    def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
         cls._collect_rpc_handlers()
-        return cls._stub_type(p2p, peer)
+        return cls._stub_type(p2p, peer, namespace)
 
     @classmethod
-    def _get_handle_name(cls, method_name: str) -> str:
-        return f"{cls.__name__}.{method_name}"
+    def _get_handle_name(cls, namespace: Optional[str], method_name: str) -> str:
+        handle_name = f"{cls.__name__}.{method_name}"
+        if namespace is not None:
+            handle_name = f"{namespace}::{handle_name}"
+        return handle_name
 
     @staticmethod
     def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:
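
For reference, a minimal self-contained sketch of the naming rule above (the class name here is hypothetical): without a namespace, handle names keep their old ClassName.method form; with one, they gain a "namespace::" prefix.

from typing import Optional

class DemoServicer:
    # Stand-alone copy of the namespacing rule from _get_handle_name above.
    @classmethod
    def _get_handle_name(cls, namespace: Optional[str], method_name: str) -> str:
        handle_name = f"{cls.__name__}.{method_name}"
        if namespace is not None:
            handle_name = f"{namespace}::{handle_name}"
        return handle_name

print(DemoServicer._get_handle_name(None, "rpc_join_group"))      # DemoServicer.rpc_join_group
print(DemoServicer._get_handle_name("my_exp", "rpc_join_group"))  # my_exp::DemoServicer.rpc_join_group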

+ 1 - 0
tests/test_allreduce.py

@@ -193,6 +193,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
         allreduce_protocol = AllReduceRunner(
             p2p=p2p,
             servicer=AllReduceRunner,
+            prefix=None,
             group_id=group_id,
             tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
             ordered_group_endpoints=peers,
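
With prefix=None, _get_handle_name falls back to the un-namespaced ClassName.method form, so this test continues to exercise the pre-namespace wire format.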