
Start converting averager to libp2p backend

Aleksandr Borzunov, 4 years ago
Commit ea3b56d479
3 changed files with 17 additions and 84 deletions
  1. hivemind/averaging/averager.py (+15 -72)
  2. hivemind/dht/__init__.py (+1 -1)
  3. tests/test_averaging.py (+1 -11)

+ 15 - 72
hivemind/averaging/averager.py

@@ -8,17 +8,13 @@ import ctypes
 import multiprocessing as mp
 import os
 import threading
-import uuid
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
-from ipaddress import ip_address
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
-import grpc
 import numpy as np
 import torch
-from grpc._cython.cygrpc import InternalError
 
 from hivemind.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.averaging.group_info import GroupInfo
@@ -26,24 +22,22 @@ from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.dht import DHT, DHTID
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor
+from hivemind.p2p import P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
+from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.utils import MPFuture, get_logger, TensorDescriptor
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
-from hivemind.utils.networking import choose_ip_address, strip_port, Hostname
+from hivemind.utils.grpc import split_for_streaming, combine_from_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
 
 # flavour types
-StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
 GatheredData = Any
 logger = get_logger(__name__)
 
 
-class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
+class DecentralizedAverager(mp.Process, ServicerBase):
     """
-
     Parameter averaging service. A trainer can run this service in background to periodically average his parameters
     with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
     group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.
@@ -69,12 +63,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
           If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
     :param client_mode: if False (default), this averager will accept incoming requests from other peers
             if True, the averager will only join existing groups where at least one peer has client_mode=False
-    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-    :param announced_host: visible IP address the averager will announce for external connections from other peers.
-          If None, the address will be chosen from p2p.get_visible_maddrs() (global IPv4 addresses are preferred)
-    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
-          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
-    :param kwargs: extra parameters forwarded to grpc.aio.server
     :param auxiliary: if this flag is specified, averager.step will only assist others without sending
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
@@ -96,7 +84,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
-    _server: grpc.aio.Server
     serializer = MSGPackSerializer
 
     def __init__(
@@ -120,12 +107,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
         client_mode: bool = False,
-        listen_on: Endpoint = "0.0.0.0:*",
         daemon: bool = True,
-        announced_host: Optional[str] = None,
-        channel_options: Sequence[Tuple[str, Any]] = (),
         shutdown_timeout: float = 5,
-        **kwargs,
     ):
         assert "." not in prefix, "group prefix must be a string without trailing '.'"
         assert bandwidth is None or (
@@ -138,7 +121,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         super().__init__()
         self.dht = dht
-        self.client_mode, self.listen_on, self.kwargs = client_mode, listen_on, kwargs
+        self.p2p = dht.p2p
+        self.client_mode = client_mode
         self._parent_pid = os.getpid()
         if self.client_mode:
             self.mode = AveragingMode.CLIENT
@@ -146,11 +130,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self.mode = AveragingMode.AUX
         else:
             self.mode = AveragingMode.NODE
-
-        if announced_host is None:
-            announced_host = self._choose_announced_host()
-        self.announced_host = announced_host
-        self.channel_options = channel_options
         self.daemon = daemon
 
         self._averaged_tensors = tuple(averaged_tensors)
@@ -179,17 +158,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
-        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
 
-        self._averager_endpoint: Optional[Endpoint] = None
-        if self.client_mode:
-            self._averager_endpoint = f"client::{uuid.uuid4()}"
-
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
@@ -201,22 +175,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if start:
             self.run_in_background(await_ready=True)
 
-    def _choose_announced_host(self) -> Hostname:
-        announced_host = strip_port(self.listen_on).strip("[]")  # Stripping square brackets for IPv6
-        if ip_address(announced_host) not in [ip_address("0.0.0.0"), ip_address("::")]:
-            return announced_host
-
-        maddrs = self.dht.get_visible_maddrs()
-        announced_host = choose_ip_address(maddrs)
-        logger.info(
-            f"Choosing IP {announced_host} as endpoint for DecentralizedAverager " f"from visible multiaddrs {maddrs}"
-        )
-        return announced_host
-
-    @property
-    def port(self) -> Optional[Port]:
-        return self._port.value if self._port.value != 0 else None
-
     @property
     def allow_state_sharing(self) -> bool:
         """if set to True, other peers can download this peer's state"""
@@ -230,12 +188,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self._allow_state_sharing.value = value
 
     @property
-    def endpoint(self) -> Optional[Endpoint]:
-        if self._averager_endpoint is None and not self.client_mode:
-            assert self.port is not None, "Averager is not running yet"
-            self._averager_endpoint = f"{self.announced_host}:{self.port}"
-            logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
-        return self._averager_endpoint
+    def endpoint(self) -> Endpoint:
+        return self.p2p.id
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint})"
@@ -257,15 +211,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
-                grpc.aio.init_grpc_aio()
-
                 if not self.client_mode:
-                    self._server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, self._server)
-                    found_port = self._server.add_insecure_port(self.listen_on)
-                    assert found_port != 0, f"Failed to listen to {self.listen_on}"
-                    self._port.value = found_port
-                    await self._server.start()
+                    await self.add_p2p_handlers(self.p2p)
                 else:
                     logger.debug(f"The averager is running in client mode.")
 
@@ -394,11 +341,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     MatchmakingException,
                     AssertionError,
                     StopAsyncIteration,
-                    InternalError,
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
-                    grpc.RpcError,
-                    grpc.aio.AioRpcError,
+                    P2PHandlerError,
                 ) as e:
                     time_elapsed = get_dht_time() - start_time
                     if not allow_retries or (timeout is not None and timeout < time_elapsed):
@@ -496,14 +441,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self.lock_averaged_tensors.release()
 
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
         async for response in self._matchmaking.rpc_join_group(request, context):
             yield response
 
     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a groupmate sends us a part of his tensor; we should average it with other peers and return the result"""
         request = await anext(stream)
@@ -539,7 +484,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             await asyncio.sleep(self._matchmaking.averaging_expiration)
 
     async def rpc_download_state(
-        self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.DownloadRequest, _: P2PContext
     ) -> AsyncIterator[averaging_pb2.DownloadData]:
         """
         Get the up-to-date trainer state from a peer.
@@ -610,9 +555,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     logger.info(f"Downloading parameters from peer {peer}")
                     stream = None
                     try:
-                        stub = ChannelCache.get_stub(
-                            peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True, options=self.channel_options
-                        )
+                        stub = self.get_stub(self.p2p, peer)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
                         async for message in stream:
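
A minimal usage sketch (not part of this diff), assuming the hivemind API at this revision: the averager now reuses the libp2p instance owned by the DHT, so the removed listen_on / announced_host / channel_options arguments are simply dropped, RPC handlers are registered through add_p2p_handlers, and endpoint becomes the libp2p PeerID (dht.p2p.id).

    # Hedged sketch: construct an averager on top of a libp2p-backed DHT.
    import torch
    import hivemind

    dht = hivemind.DHT(start=True)  # owns the hivemind.p2p.P2P instance
    averager = hivemind.averaging.DecentralizedAverager(
        [torch.zeros(16)],   # tensors to be averaged with groupmates
        dht=dht,
        prefix="demo-run",
        target_group_size=2,
        start=True,
    )
    print(averager.endpoint)  # a PeerID (dht.p2p.id) rather than "host:port"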

+ 1 - 1
hivemind/dht/__init__.py

@@ -25,7 +25,7 @@ from multiaddr import Multiaddr
 
 from hivemind.dht.node import DHTNode
 from hivemind.p2p import P2P
-from hivemind.dht.routing import DHTKey, DHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
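
For context, a one-line usage note (not part of this diff): with DHTID re-exported here, the import that averager.py relies on above resolves directly from the hivemind.dht package:

    from hivemind.dht import DHT, DHTID  # as used in hivemind/averaging/averager.py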
 

+ 1 - 11
tests/test_averaging.py

@@ -77,7 +77,6 @@ def _test_allreduce_once(n_clients, n_aux):
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=mode == AveragingMode.CLIENT,
-            listen_on="127.0.0.1:*",
             auxiliary=mode == AveragingMode.AUX,
             start=True,
         )
@@ -136,7 +135,6 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=client_mode,
-            listen_on="127.0.0.1:*",
             start=True,
         )
         for tensors, client_mode in zip([tensors1, tensors2, tensors3, tensors4], client_modes)
@@ -191,7 +189,6 @@ def test_allreduce_compression():
             compression_type=compression_type_pair,
             target_group_size=2,
             prefix="mygroup",
-            listen_on="127.0.0.1:*",
             start=True,
         )
 
@@ -239,7 +236,6 @@ def test_allreduce_grid():
             target_group_size=2,
             prefix="mygroup",
             initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
-            listen_on="127.0.0.1:*",
             start=True,
         )
         for i in range(8)
@@ -278,7 +274,6 @@ def test_allgather():
             averaging_expiration=15,
             prefix="mygroup",
             initial_group_bits="000",
-            listen_on="127.0.0.1:*",
             start=True,
         )
         for _ in range(8)
@@ -361,7 +356,6 @@ def test_too_few_peers():
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
-            listen_on="127.0.0.1:*",
             start=True,
         )
         for i in range(4)
@@ -387,7 +381,6 @@ def test_overcrowded(num_peers=16):
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits="",
-            listen_on="127.0.0.1:*",
             start=True,
         )
         for _ in range(num_peers)
@@ -426,7 +419,6 @@ def test_load_state_from_peers():
         start=True,
         prefix="demo-run",
         target_group_size=2,
-        listen_on="127.0.0.1:*",
     )
 
     dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
@@ -437,7 +429,6 @@ def test_load_state_from_peers():
         start=True,
         prefix="demo-run",
         target_group_size=2,
-        listen_on="127.0.0.1:*",
     )
 
     assert num_calls == 0
@@ -468,7 +459,7 @@ def test_load_state_from_peers():
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
     averager = hivemind.averaging.DecentralizedAverager(
-        [torch.randn(3)], dht=dht, start=True, prefix="test_prefix", target_group_size=2, listen_on="127.0.0.1:*"
+        [torch.randn(3)], dht=dht, start=True, prefix="test_prefix", target_group_size=2,
     )
     averager.set_group_bits("00101011101010")
     assert averager.get_group_bits() == "00101011101010"
@@ -482,7 +473,6 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     common_kwargs = {
         "dht": dht,
         "start": True,
-        "listen_on": "127.0.0.1:*",
         "prefix": "demo-run",
         "target_group_size": 2,
     }
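
A hedged two-peer sketch (not part of this diff), mirroring the updated tests above: peers now find each other through DHT initial peers (libp2p multiaddrs) rather than explicit "host:port" endpoints. It assumes DHT.get_visible_maddrs() (referenced in the removed averager code) and step(wait=False) behave as elsewhere in hivemind at this revision.

    import torch
    import hivemind

    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    averagers = [
        hivemind.averaging.DecentralizedAverager(
            [torch.randn(4)], dht=dht, prefix="demo-run", target_group_size=2, start=True
        )
        for dht in (dht1, dht2)
    ]
    # Both peers enter matchmaking via the DHT and run all-reduce over
    # libp2p streams instead of gRPC channels.
    futures = [averager.step(wait=False) for averager in averagers]
    results = [future.result() for future in futures]  # wait for both steps to finish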