
Group AllReduce protocol (#119)

This is the first part of #115 that implements averaging tensors in a (pre-determined) group of peers


Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
0595f4af90

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.11'
+__version__ = '0.8.12'

+ 1 - 0
hivemind/client/__init__.py

@@ -1,2 +1,3 @@
 from hivemind.client.expert import RemoteExpert
 from hivemind.client.moe import RemoteMixtureOfExperts
+from hivemind.client.averager import DecentralizedAverager

+ 358 - 0
hivemind/client/allreduce.py

@@ -0,0 +1,358 @@
+""" This file contains a state machine that defines allreduce protocol used in DecentralizedAverager """
+from __future__ import annotations
+import asyncio
+import random
+from dataclasses import asdict
+from typing import Set, Optional, Sequence, Tuple, Dict, AsyncIterator
+from enum import Enum, auto
+
+import grpc
+import torch
+
+from hivemind.dht import DHTID, DHTExpiration
+from hivemind.utils import Endpoint, get_logger, MSGPackSerializer
+from hivemind.utils import TensorDescriptor, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+
+logger = get_logger(__name__)
+
+# flavour types
+GroupID = bytes
+
+
+class ProtocolState(Enum):
+    LOOKING_FOR_GROUP = auto()   # I want to run averaging, but haven't found any peers yet
+    LEADER_WAITING_FOR_PEERS = auto()     # I am a leader, waiting for more peers to join
+    FOLLOWER_WAITING_FOR_LEADER = auto()  # I am a follower, my leader is assembling the group
+    RUNNING_ALLREDUCE = auto()   # we are currently exchanging tensors in a group
+    FINISHED_NORMALLY = auto()   # we ran allreduce and finished without errors
+    GROUP_DISBANDED = auto()     # leader disbanded the group before we began allreduce
+    ERROR = auto()               # someone (maybe I) messed up and we can't recover
+    CANCELLED = auto()           # I have unilaterally cancelled GroupAllReduce
+
+
+class GroupAllReduce:
+    """
+    An internal class that keeps track of one group allreduce for DecentralizedAverager.
+    GroupAllReduce is meant to be modified through its methods; no direct variable assignment is allowed outside of debugging.
+
+    :param endpoint: my endpoint, as seen by the group leader
+    :param expiration: the time after which the group should begin allreduce or be disbanded
+    :param tensors: a sequence of torch tensors that i intend to average with peers
+    """
+    compression_type = runtime_pb2.NONE
+
+    def __init__(self, endpoint: Endpoint, expiration: DHTExpiration, tensors: Sequence[torch.Tensor]):
+        assert all(tensor.dtype == torch.float32 and tensor.device == torch.device('cpu') for tensor in tensors)
+        self.local_tensors = tensors
+        self.state = ProtocolState.LOOKING_FOR_GROUP
+        self.info = averaging_pb2.PeerInfo(endpoint=endpoint, expiration=expiration,
+                                           schema_hash=compute_schema_hash(tensors))
+
+        self.leader_endpoint: Optional[Endpoint] = None
+        self.group_id: Optional[GroupID] = None  # a unique identifier of this one group all-reduce
+        self.max_size = float('inf')  # maximum group size, only enforced for group leader
+
+        # populated when assembling a group
+        self.group_endpoints_set: Set[Endpoint] = set()
+        self.assembled_group: asyncio.Future[Sequence[Endpoint]] = asyncio.Future()  # final ordered endpoints
+        self.concurrent_requests_lock = asyncio.Lock()  # lock inbound/outbound requests to join group
+
+        # populated when running allreduce
+        self.accumulator: Optional[torch.Tensor] = None   # running sum of tensor parts received so far (created lazily)
+        self.accumulated_from: Set[Endpoint] = set()      # peers whose parts have been added to the accumulator
+        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()
+
+        self.average_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers
+        self.averaged_tensors: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.info.endpoint}, {self.state})"
+
+    def __await__(self):
+        return self.averaged_tensors.__await__()
+
+    def start_new_group(self, max_size: Optional[int] = None):
+        """ Create new group with a random id, become its leader and the only participant """
+        assert self.state == ProtocolState.LOOKING_FOR_GROUP
+        self.group_id = DHTID.generate().to_bytes()
+        # note: we generate group_id as DHTID for convenience. Do not assume that it has DHTID-like properties
+        logger.debug(f"{self} - starting a new group as a leader. Group id: {self.group_id}")
+        self.state = ProtocolState.LEADER_WAITING_FOR_PEERS
+        self.leader_endpoint = self.info.endpoint
+        self.group_endpoints_set = {self.info.endpoint}
+        if max_size is not None:
+            self.max_size = max_size
+
+    @property
+    def group_size(self):
+        assert self.state in (ProtocolState.LEADER_WAITING_FOR_PEERS, ProtocolState.RUNNING_ALLREDUCE)
+        return len(self.group_endpoints_set)
+
+    def join_group(self, leader_endpoint: Endpoint, group_id: GroupID):
+        """ After you were accepted by a leader, create your local instance using the metadata he sent """
+        self.group_id, self.leader_endpoint = group_id, leader_endpoint
+        logger.debug(f"{self} - joining the group of {leader_endpoint}. Group id: {self.group_id}")
+        self.state = ProtocolState.FOLLOWER_WAITING_FOR_LEADER
+
+    def add_peer_to_group(self, follower: Endpoint):
+        """ Add peer to a group, assuming that he can be added (self.get_reasons_to_reject(peer) is None) """
+        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS
+        assert follower not in self.group_endpoints_set
+        self.group_endpoints_set.add(follower)
+        logger.debug(f"{self} - adding {follower} to my group. New size = {self.group_size}")
+        if self.group_size > self.max_size:
+            logger.warning(f"{self} - group size ({self.group_size}) exceeded max size ({self.max_size})")
+
+    def remove_peer_from_group(self, follower: Endpoint):
+        """ Remove a disconnected peer from current group """
+        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS
+        assert follower in self.group_endpoints_set and follower != self.leader_endpoint
+        self.group_endpoints_set.remove(follower)
+        logger.info(f"{self} - removed {follower} from the group. New size = {self.group_size}")
+
+    def disband_group(self):
+        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS and self.group_size == 1
+        logger.info(f"{self} - disbanded group (reason = empty)")
+        self.state = ProtocolState.LOOKING_FOR_GROUP
+
+    def leader_begin_allreduce(self):
+        """ As a leader, distribute allreduce metadata to peers and start allreduce """
+        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS and self.group_size > 1
+        logger.debug(f"{self} - initiating allreduce for {self.group_endpoints_set} peers.")
+        ordered_group_endpoints = list(self.group_endpoints_set)
+        random.shuffle(ordered_group_endpoints)
+        self.assembled_group.set_result(ordered_group_endpoints)
+        self.state = ProtocolState.RUNNING_ALLREDUCE
+
+    def follower_begin_allreduce(self, ordered_group_endpoints: Sequence[Endpoint]):
+        """ As a follower, receive the final list of peers from the leader and begin sending data around """
+        assert self.state == ProtocolState.FOLLOWER_WAITING_FOR_LEADER and self.info.endpoint in ordered_group_endpoints
+        logger.debug(f"{self} - received peer order from the leader, beginning allreduce.")
+        self.group_endpoints_set = set(ordered_group_endpoints)
+        self.assembled_group.set_result(ordered_group_endpoints)
+        self.state = ProtocolState.RUNNING_ALLREDUCE
+
+    async def accumulate(self, source: Endpoint, part: torch.Tensor) -> torch.Tensor:
+        """ Add vector part to accumulator, wait for all other vectors to be added, return the average """
+        assert source not in self.accumulated_from, "duplicate endpoint, already received that part"
+        assert self.accumulator is None or self.accumulator.shape == part.shape
+        logger.debug(f"{self} - accumulated part from {source}")
+
+        self.accumulator = part if self.accumulator is None else self.accumulator.add_(part)
+        self.accumulated_from.add(source)
+
+        ordered_group_endpoints = await self.assembled_group
+        assert len(self.accumulated_from) <= len(ordered_group_endpoints)
+        if len(self.accumulated_from) == len(ordered_group_endpoints):
+            self.averaged_part.set_result(self.accumulator.div_(len(self.accumulated_from)))
+
+        return await self.averaged_part
+
+    def _get(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
+        """ TODO this function is deprecated and will be replaced by a shared channel cache """
+        channel = grpc.aio.insecure_channel(peer)
+        return averaging_pb2_grpc.DecentralizedAveragingStub(channel)
+
+    async def handle_join_request(self, request: averaging_pb2.PeerInfo
+                                  ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
+        """ accept or reject a join request; if accepted, run him through allreduce steps """
+        should_remove_peer = False
+        try:
+            # stage 1: check if there is a reason to reject a peer outright
+            if not is_valid_join_request(request):
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
+                return
+            if self.info.expiration > (request.expiration or float('inf')):
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
+                return
+            elif request.schema_hash != self.info.schema_hash:
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+                return
+            elif request.endpoint == self.info.endpoint or request.endpoint in (self.group_endpoints_set or ()):
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
+                return
+            elif self.state == ProtocolState.FOLLOWER_WAITING_FOR_LEADER:
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER,
+                                                      suggested_leader=self.leader_endpoint)
+                return
+            elif self.state == ProtocolState.RUNNING_ALLREDUCE or len(self.accumulated_from) > 0:
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ALREADY_RUNNING)
+                return
+            if self.state == ProtocolState.LEADER_WAITING_FOR_PEERS and self.group_size >= self.max_size:
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
+                return
+
+            # stage 2: add peer to group, optionally start a new one
+            async with self.concurrent_requests_lock:
+                if self.state == ProtocolState.LOOKING_FOR_GROUP:
+                    self.start_new_group()
+
+                assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS
+
+                self.add_peer_to_group(request.endpoint)
+                should_remove_peer = True
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED, group_id=self.group_id)
+
+            if self.group_size >= self.max_size:
+                self.leader_begin_allreduce()
+
+            # stage 3: wait for the group to be assembled and return
+            ordered_group_endpoints = await self.assembled_group
+            if ordered_group_endpoints is not None:
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BEGIN_ALLREDUCE,
+                                                      ordered_group_endpoints=ordered_group_endpoints)
+            else:
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
+
+        except Exception as e:
+            logger.exception(e)
+            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
+
+        finally:  # this code is guaranteed to run if the iterator is destroyed prematurely
+            if should_remove_peer:
+                self.remove_peer_from_group(request.endpoint)
+                if self.group_size <= 1:
+                    self.set_exception(ValueError("All peers have left"))
+
+    async def request_join_group(self, leader: Endpoint
+                                 ) -> Optional[grpc.aio.UnaryStreamCall[averaging_pb2.MessageFromLeader]]:
+        """ request a given peer to be your leader for allreduce. if accepted, return a grpc stream """
+        assert self.state == ProtocolState.LOOKING_FOR_GROUP
+        try:
+            async with self.concurrent_requests_lock:
+                stream = self._get(leader).rpc_group_allreduce(self.info)
+                message = await stream.read()
+                logger.debug(f"{self} - requested {leader} to be my leader, received "
+                             f"{averaging_pb2.MessageCode.Name(message.code)}")
+                if message.code == averaging_pb2.ACCEPTED:
+                    self.join_group(leader, message.group_id)
+                    return stream
+
+        except Exception as e:
+            self.set_exception(e)
+
+    async def wait_for_allreduce(self, stream: grpc.aio.UnaryStreamCall[averaging_pb2.MessageFromLeader]) -> bool:
+        """ the second part of request_join_group, return True if started allreduce, False if failed or disbanded """
+        try:
+            message = await stream.read()
+            if message.code == averaging_pb2.BEGIN_ALLREDUCE:
+                logger.debug(f"{self} - leader triggered allreduce")
+                assert all(isinstance(p, Endpoint) for p in message.ordered_group_endpoints)
+                self.follower_begin_allreduce(message.ordered_group_endpoints)
+                return True
+            else:
+                logger.debug(f"{self} - leader sent {averaging_pb2.MessageCode.Name(message.code)}, leaving group")
+                self.state = ProtocolState.GROUP_DISBANDED
+                return False
+        except Exception as e:
+            self.set_exception(e)
+            return False
+
+    async def run_allreduce(self) -> Sequence[torch.Tensor]:
+        """ send allreduce requests to all peers and collect results, return the averaged tensor """
+        assert self.state == ProtocolState.RUNNING_ALLREDUCE
+        ordered_group_endpoints = await self.assembled_group
+        ordered_local_parts = split_into_parts(self.local_tensors, group_size=self.group_size)
+
+        async def send_part(peer_endpoint: Endpoint, local_part: torch.Tensor):
+            if peer_endpoint == self.info.endpoint:
+                self.average_tensor_parts[peer_endpoint] = await self.accumulate(peer_endpoint, local_part)
+            else:
+                serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
+                response = await self._get(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
+                    group_id=self.group_id, endpoint=self.info.endpoint, tensor_part=serialized_tensor_part))
+
+                if response.code == averaging_pb2.ACCEPTED:
+                    self.average_tensor_parts[peer_endpoint] = deserialize_torch_tensor(response.tensor_part)
+                else:
+                    raise ValueError(f"peer {peer_endpoint} replied {averaging_pb2.MessageCode.Name(response.code)}")
+
+            if len(self.average_tensor_parts) >= len(self.group_endpoints_set):
+                ordered_parts = [self.average_tensor_parts[peer] for peer in ordered_group_endpoints]
+                tensor_shapes = [tensor.shape for tensor in self.local_tensors]
+                self.averaged_tensors.set_result(restore_from_parts(ordered_parts, tensor_shapes))
+
+        try:
+            await asyncio.gather(*map(send_part, ordered_group_endpoints, ordered_local_parts))
+            return await self.averaged_tensors
+        except Exception as e:
+            code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
+
+            async def send_error_to_peer(peer_endpoint):
+                await self._get(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
+                    group_id=self.group_id, endpoint=self.info.endpoint, code=code))
+            for peer_endpoint in ordered_group_endpoints:
+                asyncio.create_task(send_error_to_peer(peer_endpoint))
+            if code == averaging_pb2.CANCELLED:
+                self.cancel()
+            else:
+                self.set_exception(e)
+            raise
+
+    async def handle_accumulate_request(self, request: averaging_pb2.AveragingData) -> averaging_pb2.AveragingData:
+        """ respond to an incoming rpc_accumulate_part """
+        if self.state not in (ProtocolState.RUNNING_ALLREDUCE, ProtocolState.FOLLOWER_WAITING_FOR_LEADER):
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
+        elif request.group_id != self.group_id:
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
+        elif request.endpoint in self.accumulated_from:
+            return averaging_pb2.AveragingData(code=averaging_pb2.DUPLICATE_ENDPOINT)
+
+        if request.code in (averaging_pb2.INTERNAL_ERROR, averaging_pb2.CANCELLED):
+            self.set_exception(ValueError(f"{request.endpoint} sent {averaging_pb2.MessageCode.Name(request.code)}"))
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
+
+        try:
+            received_part = deserialize_torch_tensor(request.tensor_part)
+            averaged_part = await self.accumulate(request.endpoint, received_part)
+            serialized = serialize_torch_tensor(averaged_part, request.tensor_part.compression, allow_inplace=False)
+            return averaging_pb2.AveragingData(code=averaging_pb2.ACCEPTED, tensor_part=serialized)
+        except asyncio.CancelledError:
+            self.cancel()
+            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        except Exception as e:
+            self.set_exception(e)
+            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+
+    def cancel(self):
+        logger.debug(f"{self} - cancelled")
+        self.state = ProtocolState.CANCELLED
+        for future in self.assembled_group, self.averaged_part, self.averaged_tensors:
+            future.cancel()
+
+    def set_exception(self, exception: Exception):
+        logger.debug(f"{self} - {exception}")
+        self.state = ProtocolState.ERROR
+        for future in self.assembled_group, self.averaged_part, self.averaged_tensors:
+            future.set_exception(exception)
+
+
+def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor, ...]:
+    """ flattens and concatenates the tensors into one vector, then splits it into group_size chunks of nearly equal size """
+    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
+    chunk_slices = torch.linspace(start=0, end=len(flat_tensor), steps=group_size + 1, dtype=torch.int64)
+    chunk_slices[-1] = len(flat_tensor)
+    return tuple(torch.as_tensor(flat_tensor[chunk_slices[i]: chunk_slices[i + 1]]) for i in range(group_size))
+
+
+def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
+    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
+    flat_tensor = torch.cat(list(chunks))
+    result_sizes = tuple(map(torch.Size.numel, shapes))
+    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
+    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
+
+
+def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
+    """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
+    schema_dicts = [{field_name: str(field_value)
+                    for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
+                    for tensor in tensors]
+    return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
+
+
+def is_valid_join_request(request: averaging_pb2.PeerInfo) -> bool:
+    assert len(request.ListFields()) == 3, "this function assumes PeerInfo has three fields, it should be updated"
+    return (isinstance(request.schema_hash, bytes) and
+            isinstance(request.expiration, DHTExpiration) and
+            isinstance(request.endpoint, Endpoint))
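
For illustration, here is a minimal sketch of the chunking helpers defined above (the tensor sizes are arbitrary, and the last assertion assumes TensorDescriptor only captures metadata such as shape, dtype and device):

import torch
from hivemind.client.allreduce import split_into_parts, restore_from_parts, compute_schema_hash

tensors = [torch.randn(3, 4), torch.zeros(5)]              # 17 elements in total
chunks = split_into_parts(tensors, group_size=3)           # 3 chunks of roughly equal size
assert sum(chunk.numel() for chunk in chunks) == 17

# restore_from_parts reverses the flatten-and-split operation exactly
restored = restore_from_parts(chunks, [t.shape for t in tensors])
assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))

# peers whose tensors share shapes/dtypes/devices get the same schema hash and can join the same group
assert compute_schema_hash(tensors) == compute_schema_hash([torch.empty(3, 4), torch.empty(5)])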

+ 184 - 0
hivemind/client/averager.py

@@ -0,0 +1,184 @@
+""" A background process that averages your tensors with peers """
+
+from __future__ import annotations
+
+import ctypes
+from typing import Sequence, Optional, Tuple, Any, Union, Awaitable, Dict
+from concurrent.futures.thread import ThreadPoolExecutor
+import multiprocessing as mp
+import asyncio
+
+import torch
+import uvloop
+import grpc
+
+import hivemind
+from hivemind.dht import get_dht_time, DHTExpiration
+from hivemind.utils import get_logger, Endpoint, Port, MPFuture
+from hivemind.client.allreduce import GroupAllReduce, GroupID
+from hivemind.proto import averaging_pb2, averaging_pb2_grpc
+
+logger = get_logger(__file__)
+
+
+class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
+    """
+    **Warning!** Decentralized averager is in active development, some critical functionality is still underway
+
+    Gating function averaging service. A trainer can run this service in the background to periodically average its
+    gating function with other trainers. The averaging pattern is chosen so that (1) you only need to average with a
+    small group of peers at a time, but (2) all trainers converge to the global average in a logarithmic number of steps.
+    Why averaging is valid: see https://github.com/learning-at-home/hivemind/issues/95#issuecomment-688806705
+    On global convergence: see https://github.com/learning-at-home/hivemind/issues/95#issuecomment-717719400
+
+    :param averaged_tensors: a sequence of pytorch tensors that will be averaged in each all-reduce
+    :param dht: a DHT node that will be used to find groups
+    :param start: if True, starts the background process immediately
+    :param timeout: consider allreduce failed if there was no activity for this many **seconds**
+    :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce;
+            if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
+    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
+    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
+          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
+    :param kwargs: extra parameters forwarded to grpc.aio.server
+    You can perform averaging using DecentralizedOptimizer (see below) or by manually running each step as follows:
+
+    >> TODO add a working example
+    """
+
+    def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
+                 max_size: int = None, timeout: float = 15, listen: bool = True, listen_on: Endpoint = '0.0.0.0:*',
+                 receiver_threads: int = 1, channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
+        super().__init__()
+        self.dht = dht
+        self.server_opts = listen, listen_on, receiver_threads, kwargs
+        self.max_size = max_size if max_size is not None else float('inf')
+        self.timeout = timeout
+        self.channel_options = channel_options
+        self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
+        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
+        self._pending_groups: Dict[GroupID, GroupAllReduce] = {}
+        self._lock_forming_a_group: Optional[asyncio.Lock] = None
+        self._forming_group: Optional[GroupAllReduce] = None
+        self.ready = mp.Event()
+
+        self.averaged_tensors = tuple(averaged_tensors)
+        for tensor in self.averaged_tensors:
+            assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
+            tensor.share_memory_()
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    @property
+    def port(self) -> Optional[Port]:
+        return self._port.value if self._port.value != 0 else None
+
+    def run(self):
+        """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
+        if asyncio.get_event_loop().is_running():
+            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+
+        uvloop.install()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        listen, listen_on, receiver_threads, server_kwargs = self.server_opts
+        pipe_awaiter = ThreadPoolExecutor(receiver_threads)
+        self._lock_forming_a_group = asyncio.Lock()
+
+        async def _run():
+            if listen:
+                grpc.aio.init_grpc_aio()
+                server = grpc.aio.server(**server_kwargs)
+                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
+                found_port = server.add_insecure_port(listen_on)
+                assert found_port != 0, f"Failed to listen to {listen_on}"
+                self._port.value = found_port
+                await server.start()
+                self.ready.set()
+            else:
+                raise NotImplementedError("Client-only averaging is not implemented yet.")
+
+            while True:
+                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                asyncio.create_task(getattr(self, method)(*args, **kwargs))
+
+        loop.run_until_complete(_run())
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts the averager in a background process. If await_ready is True, this method will wait until the background
+        averager is ready to process incoming requests, or for at most :timeout: seconds.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+
+    def shutdown(self) -> None:
+        """ Shut down the averager process """
+        # TODO notify peers before terminating
+        if self.is_alive():
+            self.terminate()
+        else:
+            logger.warning("DHT shutdown has no effect: the process is not alive")
+
+    def group_allreduce(self, my_endpoint: Endpoint, leader_endpoint: Optional[Endpoint] = None,
+                        return_future=False) -> Union[Sequence[torch.Tensor], Awaitable[Sequence[torch.Tensor]]]:
+        """
+        Set up the averager to look for a group and run all-reduce once, optionally await and return outcome
+
+        :note: this function is implemented for debugging and will be removed in future versions
+        :param my_endpoint: public endpoint of this averager
+        :param leader_endpoint: if specified, attempts to join this peer's group
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        """
+        expiration = get_dht_time() + self.timeout
+        assert isinstance(expiration, DHTExpiration)
+
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_group_allreduce', [], dict(my_endpoint=my_endpoint, expiration=expiration,
+                                                     leader_endpoint=leader_endpoint, future=_future)))
+        return future if return_future else future.result()
+
+    async def _group_allreduce(self, *, my_endpoint: Endpoint, expiration: DHTExpiration,
+                               leader_endpoint: Optional[Endpoint], future: MPFuture):
+        group_allreduce = GroupAllReduce(my_endpoint, expiration, self.averaged_tensors)
+        try:
+            if leader_endpoint is None:
+                async with self._lock_forming_a_group:
+                    group_allreduce.start_new_group(max_size=self.max_size)
+                    self._forming_group = self._pending_groups[group_allreduce.group_id] = group_allreduce
+                    await asyncio.wait_for(group_allreduce.assembled_group, expiration - get_dht_time())
+
+                future.set_result(await group_allreduce.run_allreduce())
+            else:
+                async with self._lock_forming_a_group:
+                    stream = await group_allreduce.request_join_group(leader_endpoint)
+                    self._forming_group = self._pending_groups[group_allreduce.group_id] = group_allreduce
+
+                started_allreduce = await group_allreduce.wait_for_allreduce(stream)
+                if started_allreduce:
+                    future.set_result(await group_allreduce.run_allreduce())
+                else:
+                    future.set_exception(ValueError(f"Rejected by {leader_endpoint}"))
+
+        except Exception as e:
+            future.set_exception(e)
+        finally:
+            _ = self._pending_groups.pop(group_allreduce.group_id, None)
+            if group_allreduce is self._forming_group:
+                self._forming_group = None
+
+    async def rpc_group_allreduce(self, request: averaging_pb2.PeerInfo, context: grpc.ServicerContext):
+        """ A peer wants me to be his leader. I will coordinate his actions with the rest of my group. Maybe. """
+        if self._forming_group is None:
+            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)
+            return
+        async for message in self._forming_group.handle_join_request(request):
+            yield message
+
+    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
+        if request.group_id not in self._pending_groups:
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
+        return await self._pending_groups[request.group_id].handle_accumulate_request(request)
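
For reference, the intended end-to-end call pattern (adapted from tests/test_averaging.py in this commit; the endpoints, tensor sizes and timeouts are illustrative, and group_allreduce is explicitly marked above as a temporary debugging interface) looks roughly like this:

import time
import torch
import hivemind
from hivemind.utils import LOCALHOST

dht = hivemind.DHT(start=True)
averager1 = hivemind.DecentralizedAverager([torch.randn(16)], dht=dht, start=True, max_size=2, timeout=5)
averager2 = hivemind.DecentralizedAverager([torch.rand(16)], dht=dht, start=True, max_size=2, timeout=5)

# averager1 looks for a group and becomes the leader; averager2 joins it as a follower
future1 = averager1.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager1.port}", return_future=True)
time.sleep(0.1)  # give the leader a moment to start forming the group
future2 = averager2.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager2.port}",
                                    leader_endpoint=f"{LOCALHOST}:{averager1.port}", return_future=True)

# once the group reaches max_size, the leader triggers allreduce and both futures resolve to the averaged tensors
# (this assumes MPFuture exposes a blocking .result(); the test awaits these futures from an async context instead)
averaged1, averaged2 = future1.result(), future2.result()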

+ 47 - 0
hivemind/proto/averaging.proto

@@ -0,0 +1,47 @@
+syntax = "proto3";
+import "runtime.proto";
+
+
+// Runs alongside each trainer to perform gating function averaging every now and then. Read more: client/averager.py
+service DecentralizedAveraging {
+  rpc rpc_group_allreduce(PeerInfo) returns (stream MessageFromLeader);  // assemble a group and run all-reduce
+  rpc rpc_aggregate_part(AveragingData) returns (AveragingData);  // send my local shard => get aggregated shard
+}
+
+message PeerInfo {
+  string endpoint = 1;          // A follower accepts incoming allreduce requests at this address
+  bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
+  double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
+}
+
+enum MessageCode {
+  // response to join request
+  ACCEPTED = 0;              // "I accept you in my group, you will not commit to responding to me."
+  NOT_A_LEADER = 1;          // "I am not a group leader. Go ask my leader instead."
+  ALREADY_RUNNING = 2;       // "My group has already begun merging. Here's the group leader."
+  NOT_LOOKING_FOR_GROUP = 3; // "I'm not available at the moment. Please, get lost."
+  BAD_EXPIRATION_TIME = 4;   // "I will not accept you. I cannot guarantee that we begin before you expire."
+  BAD_SCHEMA_HASH = 5;       // "I will not accept you. I am not averaging the same type of tensors as you."
+  DUPLICATE_ENDPOINT = 6;    // "I will not accept you, I already have exactly the same endpoint in my current group"
+  GROUP_IS_FULL = 7;         // "I will not accept you, my group already contains too many peers"
+  BEGIN_ALLREDUCE = 8;       // "We can begin allreduce now. These are your peers."
+  GROUP_DISBANDED = 9;       // "The group is closed. Go find another group."
+  UNKNOWN_GROUP_ID = 10;     // "Your request uses a group id that doesn't match any group I know"
+  PROTOCOL_VIOLATION = 11;   // "One of the peers did something in violation of the allreduce protocol"
+  INTERNAL_ERROR = 12;       // "We encountered an unexpected error on our side"
+  CANCELLED = 13;            // "A peer cancelled allreduce while averaging"
+}
+
+message MessageFromLeader {
+  MessageCode code = 1;
+  bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
+  string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
+  repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
+}
+
+message AveragingData {
+  MessageCode code = 1;     // in case of a protocol violation, this will be the error code
+  bytes group_id = 2;        // a unique group identifier, same as in MessageFromLeader
+  string endpoint = 3;      // sender's rpc endpoint, used for coordination
+  Tensor tensor_part = 4;    // either peer's local tensor part (rpc input) or group average of this part (rpc output)
+}

+ 117 - 0
tests/test_averaging.py

@@ -0,0 +1,117 @@
+import asyncio
+import random
+import time
+from itertools import product
+
+import torch
+import pytest
+import hivemind
+from hivemind.client.allreduce import GroupAllReduce, split_into_parts, restore_from_parts
+from hivemind.utils import LOCALHOST
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_allreduce_direct():
+    # WARNING! this test uses an early interface that will change by the time DecentralizedAverager is finished
+
+    dht = hivemind.DHT(start=True)
+
+    tensors1 = [torch.randn(123), torch.zeros(3)]
+    tensors2 = [torch.rand(123), torch.ones(3)]
+    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
+
+    reference = [(tensors1[i] + tensors2[i] + tensors3[i]) / 3 for i in range(len(tensors1))]
+
+    averager1 = hivemind.DecentralizedAverager(tensors1, dht=dht, start=True, max_size=3, timeout=5)
+    averager2 = hivemind.DecentralizedAverager(tensors2, dht=dht, start=True, max_size=3, timeout=5)
+    averager3 = hivemind.DecentralizedAverager(tensors3, dht=dht, start=True, max_size=3, timeout=5)
+
+    future1 = averager1.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager1.port}",
+                                        leader_endpoint=None, return_future=True)
+    time.sleep(0.1)
+
+    future2 = averager2.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager2.port}",
+                                        leader_endpoint=f"{LOCALHOST}:{averager1.port}",
+                                        return_future=True)
+
+    future3 = averager3.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager3.port}",
+                                        leader_endpoint=f"{LOCALHOST}:{averager1.port}",
+                                        return_future=True)
+
+    for future in future1, future2, future3:
+        for ref, our in zip(reference, await future):
+            assert torch.allclose(ref, our)
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_allreduce_protocol():
+    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
+    peers = "alice", "bob", "carol"
+    expiration_offsets = 4, 0, 1
+
+    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
+                       for i, peer in enumerate(peers)}
+
+    alice, bob, carol = allreduce_protocols = [
+        GroupAllReduce(endpoint=peer, expiration=hivemind.get_dht_time() + offset, tensors=tensors_by_peer[peer])
+        for peer, offset in zip(peers, expiration_offsets)]
+
+    bob.start_new_group()
+    bob.add_peer_to_group(alice.info.endpoint)
+    alice.join_group(bob.info.endpoint, bob.group_id)
+    bob.add_peer_to_group(carol.info.endpoint)
+    carol.join_group(bob.info.endpoint, bob.group_id)
+
+    bob.leader_begin_allreduce()
+    ordered_group_endpoints = await bob.assembled_group
+    assert len(ordered_group_endpoints) == len(peers)
+
+    carol.follower_begin_allreduce(ordered_group_endpoints)
+    alice.follower_begin_allreduce(ordered_group_endpoints)
+
+    chunks_by_peer = {protocol.info.endpoint: {
+        peer: part for peer, part in zip(peers, split_into_parts(protocol.local_tensors, len(ordered_group_endpoints)))
+    } for protocol in allreduce_protocols}
+
+    all_pairs = list(product(allreduce_protocols, peers))
+    random.shuffle(all_pairs)
+    await asyncio.gather(*(
+        peer_allreduce.accumulate(source_peer, chunks_by_peer[source_peer][peer_allreduce.info.endpoint])
+        for peer_allreduce, source_peer in all_pairs))
+
+    averaged_parts = await asyncio.gather(*(protocol.averaged_part for protocol in allreduce_protocols))
+    tensor_shapes = [tensor.shape for tensor in alice.local_tensors]
+    averaged_tensors = restore_from_parts(averaged_parts, tensor_shapes)
+
+    reference_tensors = [
+        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
+        for i in range(len(tensors_by_peer[peers[0]]))
+    ]
+
+    assert len(averaged_tensors) == len(reference_tensors)
+    assert all(map(torch.allclose, averaged_tensors, reference_tensors))
+
+
+@pytest.mark.forked
+def test_chunks():
+    for i in range(100):
+        tensors = []
+        for i in range(random.randint(1, 5)):
+            ndim = random.randint(0, 4)
+            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
+            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
+            tensors.append(make_tensor(shape))
+
+        total_size = sum(map(torch.Tensor.numel, tensors))
+        if total_size == 0:
+            continue
+        num_chunks = random.randint(1, min(1000, total_size))
+        chunks = split_into_parts(tensors, group_size=num_chunks)
+        assert len(chunks) == num_chunks
+        shapes = [tensor.shape for tensor in tensors]
+        restored = restore_from_parts(chunks, shapes)
+        assert len(restored) == len(tensors)
+        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
+        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))