@@ -8,17 +8,13 @@ import ctypes
 import multiprocessing as mp
 import os
 import threading
-import uuid
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
-from ipaddress import ip_address
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
-import grpc
 import numpy as np
 import torch
-from grpc._cython.cygrpc import InternalError
 
 from hivemind.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.averaging.group_info import GroupInfo
@@ -26,24 +22,22 @@ from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.dht import DHT, DHTID
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor
+from hivemind.p2p import P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
+from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.utils import MPFuture, get_logger, TensorDescriptor
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
-from hivemind.utils.networking import choose_ip_address, strip_port, Hostname
+from hivemind.utils.grpc import split_for_streaming, combine_from_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
 
 # flavour types
-StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
 GatheredData = Any
 logger = get_logger(__name__)
 
 
-class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
+class DecentralizedAverager(mp.Process, ServicerBase):
     """
-
     Parameter averaging service. A trainer can run this service in background to periodically average his parameters
     with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
     group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.
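
With this hunk, DecentralizedAverager derives from hivemind's ServicerBase instead of the generated gRPC servicer, and PeerID is imported under the old Endpoint alias so existing annotations keep working. A minimal sketch of the resulting pattern, assuming (as the run() hunk below relies on) that ServicerBase exposes every async rpc_* method as a libp2p handler once add_p2p_handlers() is awaited:

    # Sketch only: a toy servicer in the style this patch adopts.
    from hivemind.p2p import P2PContext, ServicerBase
    from hivemind.proto import averaging_pb2

    class ToyServicer(ServicerBase):
        async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: P2PContext):
            # a real handler would stream several MessageFromLeader responses
            yield averaging_pb2.MessageFromLeader()

    # inside a running event loop, given a hivemind.p2p.P2P instance `p2p`:
    #     await ToyServicer().add_p2p_handlers(p2p)
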
@@ -69,12 +63,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
     :param client_mode: if False (default), this averager will accept incoming requests from other peers
       if True, the averager will only join existing groups where at least one peer has client_mode=False
-    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-    :param announced_host: visible IP address the averager will announce for external connections from other peers.
-      If None, the address will be chosen from p2p.get_visible_maddrs() (global IPv4 addresses are preferred)
-    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
-      see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
-    :param kwargs: extra parameters forwarded to grpc.aio.server
     :param auxiliary: if this flag is specified, averager.step will only assist others without sending
       local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
@@ -96,7 +84,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
-    _server: grpc.aio.Server
     serializer = MSGPackSerializer
 
     def __init__(
@@ -120,12 +107,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
         client_mode: bool = False,
-        listen_on: Endpoint = "0.0.0.0:*",
         daemon: bool = True,
-        announced_host: Optional[str] = None,
-        channel_options: Sequence[Tuple[str, Any]] = (),
         shutdown_timeout: float = 5,
-        **kwargs,
     ):
         assert "." not in prefix, "group prefix must be a string without trailing '.'"
         assert bandwidth is None or (
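
The removed parameters mean callers no longer configure transport here: the averager reuses the DHT's libp2p instance (see `self.p2p = dht.p2p` in the next hunk). A hypothetical before/after call, with tensor and DHT setup elided:

    # Hypothetical usage; `tensors` and `dht` are assumed to exist.
    # Before this patch, the transport was configured per averager:
    #     DecentralizedAverager(tensors, dht, prefix="demo", target_group_size=4,
    #                           listen_on="0.0.0.0:1337")
    # After it, only averaging options remain:
    averager = DecentralizedAverager(tensors, dht, prefix="demo", target_group_size=4, start=True)
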
@@ -138,7 +121,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         super().__init__()
         self.dht = dht
-        self.client_mode, self.listen_on, self.kwargs = client_mode, listen_on, kwargs
+        self.p2p = dht.p2p
+        self.client_mode = client_mode
         self._parent_pid = os.getpid()
         if self.client_mode:
             self.mode = AveragingMode.CLIENT
@@ -146,11 +130,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self.mode = AveragingMode.AUX
         else:
             self.mode = AveragingMode.NODE
-
-        if announced_host is None:
-            announced_host = self._choose_announced_host()
-        self.announced_host = announced_host
-        self.channel_options = channel_options
         self.daemon = daemon
 
         self._averaged_tensors = tuple(averaged_tensors)
@@ -179,17 +158,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
-        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
 
-        self._averager_endpoint: Optional[Endpoint] = None
-        if self.client_mode:
-            self._averager_endpoint = f"client::{uuid.uuid4()}"
-
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
@@ -201,22 +175,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if start:
             self.run_in_background(await_ready=True)
 
-    def _choose_announced_host(self) -> Hostname:
-        announced_host = strip_port(self.listen_on).strip("[]")  # Stripping square brackets for IPv6
-        if ip_address(announced_host) not in [ip_address("0.0.0.0"), ip_address("::")]:
-            return announced_host
-
-        maddrs = self.dht.get_visible_maddrs()
-        announced_host = choose_ip_address(maddrs)
-        logger.info(
-            f"Choosing IP {announced_host} as endpoint for DecentralizedAverager " f"from visible multiaddrs {maddrs}"
-        )
-        return announced_host
-
-    @property
-    def port(self) -> Optional[Port]:
-        return self._port.value if self._port.value != 0 else None
-
     @property
     def allow_state_sharing(self) -> bool:
         """if set to True, other peers can download this peer's state"""
@@ -230,12 +188,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._allow_state_sharing.value = value
 
     @property
-    def endpoint(self) -> Optional[Endpoint]:
-        if self._averager_endpoint is None and not self.client_mode:
-            assert self.port is not None, "Averager is not running yet"
-            self._averager_endpoint = f"{self.announced_host}:{self.port}"
-            logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
-        return self._averager_endpoint
+    def endpoint(self) -> Endpoint:
+        return self.p2p.id
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint})"
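
The endpoint property shrinks from host:port bookkeeping (with a client::<uuid> fallback) to the libp2p peer ID, which exists as soon as the p2p instance does, so the "Averager is not running yet" assertion has no analogue. A small sketch of the observable change, with `averager` hypothetical:

    # The endpoint is now a stable PeerID, not "host:port" or "client::<uuid>".
    assert averager.endpoint == averager.p2p.id
    print(averager.endpoint)  # a peer ID, the same in client and server mode
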
@@ -257,15 +211,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
-                grpc.aio.init_grpc_aio()
-
                 if not self.client_mode:
-                    self._server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, self._server)
-                    found_port = self._server.add_insecure_port(self.listen_on)
-                    assert found_port != 0, f"Failed to listen to {self.listen_on}"
-                    self._port.value = found_port
-                    await self._server.start()
+                    await self.add_p2p_handlers(self.p2p)
                 else:
                     logger.debug(f"The averager is running in client mode.")
 
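
The whole gRPC server lifecycle (create server, attach servicer, bind a port, start) collapses into one add_p2p_handlers() call: the libp2p host is already listening, so there is no port to discover, which is why self._port was dropped in an earlier hunk. The simplified startup, restated as a standalone sketch with explanatory comments (`self` is the averager):

    # Sketch of _run() after this patch.
    async def _run():
        if not self.client_mode:
            # register every rpc_* coroutine on the shared libp2p host;
            # no bind/listen step: the p2p daemon already owns the sockets
            await self.add_p2p_handlers(self.p2p)
        else:
            logger.debug("The averager is running in client mode.")
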
@@ -394,11 +341,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 MatchmakingException,
                 AssertionError,
                 StopAsyncIteration,
-                InternalError,
                 asyncio.CancelledError,
                 asyncio.InvalidStateError,
-                grpc.RpcError,
-                grpc.aio.AioRpcError,
+                P2PHandlerError,
             ) as e:
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
@@ -496,14 +441,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self.lock_averaged_tensors.release()
 
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
         async for response in self._matchmaking.rpc_join_group(request, context):
             yield response
 
     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a groupmate sends us a part of his tensor; we should average it with other peers and return the result"""
         request = await anext(stream)
@@ -539,7 +484,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             await asyncio.sleep(self._matchmaking.averaging_expiration)
 
     async def rpc_download_state(
-        self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.DownloadRequest, _: P2PContext
     ) -> AsyncIterator[averaging_pb2.DownloadData]:
         """
         Get the up-to-date trainer state from a peer.
@@ -610,9 +555,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 logger.info(f"Downloading parameters from peer {peer}")
                 stream = None
                 try:
-                    stub = ChannelCache.get_stub(
-                        peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True, options=self.channel_options
-                    )
+                    stub = self.get_stub(self.p2p, peer)
                     stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                     current_tensor_parts, tensors = [], []
                     async for message in stream:
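
ChannelCache and the generated DecentralizedAveragingStub give way to get_stub(p2p, peer), whose methods are assumed to mirror the rpc_* handlers by name; per-channel channel_options disappear along with the channels themselves. The call pattern in isolation, with `peer` a PeerID and `self` the averager:

    # Sketch: streaming a peer's state through the p2p stub.
    stub = self.get_stub(self.p2p, peer)
    async for message in stub.rpc_download_state(averaging_pb2.DownloadRequest()):
        ...  # each message is an averaging_pb2.DownloadData chunk
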