
Implement averaging parameters over DHT (2/3)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
eb93789ac6

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.19'
+__version__ = '0.8.20'

+ 1 - 1
hivemind/client/__init__.py

@@ -1,3 +1,3 @@
 from hivemind.client.expert import RemoteExpert
 from hivemind.client.moe import RemoteMixtureOfExperts
-from hivemind.client.averager import DecentralizedAverager
+from hivemind.client.averaging import DecentralizedAverager
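Downstream code that imported the averager by its old module path needs to switch to the new one. A small, hypothetical usage note (the package-level re-export shown above keeps working):

# old path (module removed in this commit):
#   from hivemind.client.averager import DecentralizedAverager
# new module path introduced here:
from hivemind.client.averaging import DecentralizedAverager
# equivalently, via the re-export in hivemind/client/__init__.py:
from hivemind.client import DecentralizedAverager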

+ 0 - 356
hivemind/client/allreduce.py

@@ -1,356 +0,0 @@
-""" This file contains a state machine that defines allreduce protocol used in DecentralizedAverager """
-from __future__ import annotations
-import asyncio
-import random
-from dataclasses import asdict
-from typing import Set, Optional, Sequence, Tuple, Dict, AsyncIterator
-from enum import Enum, auto
-
-import grpc
-import torch
-
-from hivemind.dht import DHTID, DHTExpiration
-from hivemind.utils import Endpoint, get_logger, MSGPackSerializer
-from hivemind.utils import TensorDescriptor, deserialize_torch_tensor, serialize_torch_tensor, ChannelCache
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-
-logger = get_logger(__name__)
-
-# flavour types
-GroupID = bytes
-
-
-class ProtocolState(Enum):
-    LOOKING_FOR_GROUP = auto()   # i want to run averaging, but haven't found any peers yet
-    LEADER_WAITING_FOR_PEERS = auto()     # i am a leader, waiting for more peers to join
-    FOLLOWER_WAITING_FOR_LEADER = auto()  # i am a follower, my leader is assembling the group
-    RUNNING_ALLREDUCE = auto()   # we are currently exchanging tensors in a group
-    FINISHED_NORMALLY = auto()   # we ran allreduce and finished without errors
-    GROUP_DISBANDED = auto()     # leader disbanded the group before we began allreduce
-    ERROR = auto()               # someone (maybe i) messed up and we can't recover
-    CANCELLED = auto()           # i have unilaterally cancelled GroupAllreduce
-
-
-class GroupAllReduce:
-    """
-    An internal class that keeps track of one group allreduce for DecentralizedAverager.
-    GroupAllReduce is meant to be modified with methods, no direct variable assignments is allowed outside of debugging.
-
-    :param endpoint: my endpoint, as seen by the group leader
-    :param expiration: the time after which the group should begin allreduce or be disbanded
-    :param tensors: a sequence of torch tensors that i intend to average with peers
-    """
-    compression_type = runtime_pb2.NONE
-
-    def __init__(self, endpoint: Endpoint, expiration: DHTExpiration, tensors: Sequence[torch.Tensor]):
-        assert all(tensor.dtype == torch.float32 and tensor.device == torch.device('cpu') for tensor in tensors)
-        self.local_tensors = tensors
-        self.state = ProtocolState.LOOKING_FOR_GROUP
-        self.info = averaging_pb2.PeerInfo(endpoint=endpoint, expiration=expiration,
-                                           schema_hash=compute_schema_hash(tensors))
-
-        self.leader_endpoint: Optional[Endpoint] = None
-        self.group_id: Optional[GroupID] = None  # a unique identifier of this one group all-reduce
-        self.max_size = float('inf')  # maximum group size, only enforced for group leader
-
-        # populated when assembling a group
-        self.group_endpoints_set: Set[Endpoint] = set()
-        self.assembled_group: asyncio.Future[Sequence[Endpoint]] = asyncio.Future()  # final ordered endpoints
-        self.concurrent_requests_lock = asyncio.Lock()  # lock inbound/outbound requests to join group
-
-        # populated when running allreduce
-        self.accumulator: Optional[torch.Tensor] = None   # the sum of averaged tensors so far, init with zeros
-        self.accumulated_from: Set[Endpoint] = set()      # peers that we have accumulated our part from
-        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()
-
-        self.average_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers
-        self.averaged_tensors: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.info.endpoint}, {self.state})"
-
-    def __await__(self):
-        return self.averaged_tensors.__await__()
-
-    def start_new_group(self, max_size: Optional[int] = None):
-        """ Create new group with a random id, become its leader and the only participant """
-        assert self.state == ProtocolState.LOOKING_FOR_GROUP
-        self.group_id = DHTID.generate().to_bytes()
-        # note: we generate group_id as DHTID for convenience. Do not assume that it has DHTID-like properties
-        logger.debug(f"{self} - starting a new group as a leader. Group id: {self.group_id}")
-        self.state = ProtocolState.LEADER_WAITING_FOR_PEERS
-        self.leader_endpoint = self.info.endpoint
-        self.group_endpoints_set = {self.info.endpoint}
-        if max_size is not None:
-            self.max_size = max_size
-
-    @property
-    def group_size(self):
-        assert self.state in (ProtocolState.LEADER_WAITING_FOR_PEERS, ProtocolState.RUNNING_ALLREDUCE)
-        return len(self.group_endpoints_set)
-
-    def join_group(self, leader_endpoint: Endpoint, group_id: GroupID):
-        """ After you were accepted by a leader, create your local instance using the metadata he sent """
-        self.group_id, self.leader_endpoint = group_id, leader_endpoint
-        logger.debug(f"{self} - joining the group of {leader_endpoint}. Group id: {self.group_id}")
-        self.state = ProtocolState.FOLLOWER_WAITING_FOR_LEADER
-
-    def add_peer_to_group(self, follower: Endpoint):
-        """ Add peer to a group, assuming that he can be added (self.get_reasons_to_reject(peer) is None) """
-        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS
-        assert follower not in self.group_endpoints_set
-        self.group_endpoints_set.add(follower)
-        logger.debug(f"{self} - adding {follower} to my group. New size = {self.group_size}")
-        if self.group_size > self.max_size:
-            logger.warning(f"{self} - group size ({self.group_size}) exceeded max size ({self.max_size})")
-
-    def remove_peer_from_group(self, follower: Endpoint):
-        """ Remove a disconnected peer from current group """
-        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS
-        assert follower in self.group_endpoints_set and follower != self.leader_endpoint
-        self.group_endpoints_set.remove(follower)
-        logger.info(f"{self} - removed {follower} from the group. New size = {self.group_size}")
-
-    def disband_group(self):
-        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS and self.group_size == 1
-        logger.info(f"{self} - disbanded group (reason = empty)")
-        self.state = ProtocolState.LOOKING_FOR_GROUP
-
-    def leader_begin_allreduce(self) -> averaging_pb2.MessageFromLeader:
-        """ As a leader, distribute allreduce metadata to peers and start allreduce """
-        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS and self.group_size > 1
-        logger.debug(f"{self} - initiating allreduce for {self.group_endpoints_set} peers.")
-        ordered_group_endpoints = list(self.group_endpoints_set)
-        random.shuffle(ordered_group_endpoints)
-        self.assembled_group.set_result(ordered_group_endpoints)
-        self.state = ProtocolState.RUNNING_ALLREDUCE
-
-    def follower_begin_allreduce(self, ordered_group_endpoints: Sequence[Endpoint]):
-        """ As a follower, receive the final list of peers from the leader and begin sending data around """
-        assert self.state == ProtocolState.FOLLOWER_WAITING_FOR_LEADER and self.info.endpoint in ordered_group_endpoints
-        logger.debug(f"{self} - received peer order from the leader, beginning allreduce.")
-        self.group_endpoints_set = set(ordered_group_endpoints)
-        self.assembled_group.set_result(ordered_group_endpoints)
-        self.state = ProtocolState.RUNNING_ALLREDUCE
-
-    async def accumulate(self, source: Endpoint, part: torch.Tensor) -> torch.Tensor:
-        """ Add vector part to accumulator, wait for all other vectors to be added, return the average """
-        assert source not in self.accumulated_from, "duplicate endpoint, already received that part"
-        assert self.accumulator is None or self.accumulator.shape == part.shape
-        logger.debug(f"{self} - accumulated part from {source}")
-
-        self.accumulator = part if self.accumulator is None else self.accumulator.add_(part)
-        self.accumulated_from.add(source)
-
-        ordered_group_endpoints = await self.assembled_group
-        assert len(self.accumulated_from) <= len(ordered_group_endpoints)
-        if len(self.accumulated_from) == len(ordered_group_endpoints):
-            self.averaged_part.set_result(self.accumulator.div_(len(self.accumulated_from)))
-
-        return await self.averaged_part
-
-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-
-    async def handle_join_request(self, request: averaging_pb2.PeerInfo
-                                  ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
-        """ accept or reject a join request; if accepted, run him through allreduce steps """
-        should_remove_peer = False
-        try:
-            # stage 1: check if there is a reason to reject a peer outright
-            if not is_valid_join_request(request):
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
-                return
-            if self.info.expiration > (request.expiration or float('inf')):
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
-            elif request.schema_hash != self.info.schema_hash:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
-                return
-            elif request.endpoint == self.info.endpoint or request.endpoint in (self.group_endpoints_set or ()):
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
-                return
-            elif self.state == ProtocolState.FOLLOWER_WAITING_FOR_LEADER:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER,
-                                                      suggested_leader=self.leader_endpoint)
-                return
-            elif self.state == ProtocolState.RUNNING_ALLREDUCE or len(self.accumulated_from) > 0:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ALREADY_RUNNING)
-                return
-            if self.state == ProtocolState.LEADER_WAITING_FOR_PEERS and self.group_size >= self.max_size:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
-                return
-
-            # stage 2: add peer to group, optionally start a new one
-            async with self.concurrent_requests_lock:
-                if self.state == ProtocolState.LOOKING_FOR_GROUP:
-                    self.start_new_group()
-
-                assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS
-
-                self.add_peer_to_group(request.endpoint)
-                should_remove_peer = True
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED, group_id=self.group_id)
-
-            if self.group_size >= self.max_size:
-                self.leader_begin_allreduce()
-
-            # stage 3: wait for the group to be assembled and return
-            ordered_group_endpoints = await self.assembled_group
-            if ordered_group_endpoints is not None:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BEGIN_ALLREDUCE,
-                                                      ordered_group_endpoints=ordered_group_endpoints)
-            else:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
-
-        except Exception as e:
-            logger.exception(e)
-            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
-
-        finally:  # this code is guaranteed to run if the iterator is destroyed prematurely
-            if should_remove_peer:
-                self.remove_peer_from_group(request.endpoint)
-                if self.group_size <= 1:
-                    self.set_exception(ValueError("All peers have left"))
-
-    async def request_join_group(self, leader: Endpoint
-                                 ) -> Optional[grpc.aio.UnaryStreamCall[averaging_pb2.MessageFromLeader]]:
-        """ request a given peer to be your leader for allreduce. if accepted, return a grpc stream """
-        assert self.state == ProtocolState.LOOKING_FOR_GROUP
-        try:
-            async with self.concurrent_requests_lock:
-                stream = self._get_peer_stub(leader).rpc_group_allreduce(self.info)
-                message = await stream.read()
-                logger.debug(f"{self} - requested {leader} to be my leader, received "
-                             f"{averaging_pb2.MessageCode.Name(message.code)}")
-                if message.code == averaging_pb2.ACCEPTED:
-                    self.join_group(leader, message.group_id)
-                    return stream
-
-        except Exception as e:
-            self.set_exception(e)
-
-    async def wait_for_allreduce(self, stream: grpc.aio.UnaryStreamCall[averaging_pb2.MessageFromLeader]) -> bool:
-        """ the second part of request_join_group, return True if started allreduce, False if failed or disbanded """
-        try:
-            message = await stream.read()
-            if message.code == averaging_pb2.BEGIN_ALLREDUCE:
-                logger.debug(f"{self} - leader triggered allreduce")
-                assert all(isinstance(p, Endpoint) for p in message.ordered_group_endpoints)
-                self.follower_begin_allreduce(message.ordered_group_endpoints)
-                return True
-            else:
-                logger.debug(f"{self} - leader sent {averaging_pb2.MessageCode.Name(message.code)}, leaving group")
-                self.state = ProtocolState.GROUP_DISBANDED
-                return False
-        except Exception as e:
-            self.set_exception(e)
-            return False
-
-    async def run_allreduce(self) -> Sequence[torch.Tensor]:
-        """ send allreduce requests to all peers and collect results, return the averaged tensor """
-        assert self.state == ProtocolState.RUNNING_ALLREDUCE
-        ordered_group_endpoints = await self.assembled_group
-        ordered_local_parts = split_into_parts(self.local_tensors, group_size=self.group_size)
-
-        async def send_part(peer_endpoint: Endpoint, local_part: torch.Tensor):
-            if peer_endpoint == self.info.endpoint:
-                self.average_tensor_parts[peer_endpoint] = await self.accumulate(peer_endpoint, local_part)
-            else:
-                serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-                response = await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
-                    group_id=self.group_id, endpoint=self.info.endpoint, tensor_part=serialized_tensor_part))
-
-                if response.code == averaging_pb2.ACCEPTED:
-                    self.average_tensor_parts[peer_endpoint] = deserialize_torch_tensor(response.tensor_part)
-                else:
-                    raise ValueError(f"peer {peer_endpoint} replied {averaging_pb2.MessageCode.Name(response.code)}")
-
-            if len(self.average_tensor_parts) >= len(self.group_endpoints_set):
-                ordered_parts = [self.average_tensor_parts[peer] for peer in ordered_group_endpoints]
-                tensor_shapes = [tensor.shape for tensor in self.local_tensors]
-                self.averaged_tensors.set_result(restore_from_parts(ordered_parts, tensor_shapes))
-
-        try:
-            await asyncio.gather(*map(send_part, ordered_group_endpoints, ordered_local_parts))
-            return await self.averaged_tensors
-        except Exception as e:
-            code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
-
-            async def send_error_to_peer(peer_endpoint):
-                await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
-                    group_id=self.group_id, endpoint=self.info.endpoint, code=code))
-            for peer_endpoint in ordered_group_endpoints:
-                asyncio.create_task(send_error_to_peer(peer_endpoint))
-            if code == averaging_pb2.CANCELLED:
-                self.cancel()
-            else:
-                self.set_exception(e)
-            raise
-
-    async def handle_accumulate_request(self, request: averaging_pb2.AveragingData) -> averaging_pb2.AveragingData:
-        """ respond to an incoming rpc_accumulate_part """
-        if self.state not in (ProtocolState.RUNNING_ALLREDUCE, ProtocolState.FOLLOWER_WAITING_FOR_LEADER):
-            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
-        elif request.group_id != self.group_id:
-            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
-        elif request.endpoint in self.accumulated_from:
-            return averaging_pb2.AveragingData(code=averaging_pb2.DUPLICATE_ENDPOINT)
-
-        if request.code in (averaging_pb2.INTERNAL_ERROR, averaging_pb2.CANCELLED):
-            self.set_exception(ValueError(f"{request.endpoint} sent {averaging_pb2.MessageCode.Name(request.code)}"))
-            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
-
-        try:
-            received_part = deserialize_torch_tensor(request.tensor_part)
-            averaged_part = await self.accumulate(request.endpoint, received_part)
-            serialized = serialize_torch_tensor(averaged_part, request.tensor_part.compression, allow_inplace=False)
-            return averaging_pb2.AveragingData(code=averaging_pb2.ACCEPTED, tensor_part=serialized)
-        except asyncio.CancelledError:
-            self.cancel()
-            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
-        except Exception as e:
-            self.set_exception(e)
-            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
-
-    def cancel(self):
-        logger.debug(f"{self} - cancelled")
-        self.state = ProtocolState.CANCELLED
-        for future in self.assembled_group, self.averaged_part, self.averaged_tensors:
-            future.cancel()
-
-    def set_exception(self, exception: Exception):
-        logger.debug(f"{self} - {exception}")
-        self.state = ProtocolState.ERROR
-        for future in self.assembled_group, self.averaged_part, self.averaged_tensors:
-            future.set_exception(exception)
-
-
-def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor]:
-    """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
-    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
-    chunk_slices = torch.linspace(start=0, end=len(flat_tensor), steps=group_size + 1, dtype=torch.int64)
-    chunk_slices[-1] = len(flat_tensor)
-    return tuple(torch.as_tensor(flat_tensor[chunk_slices[i]: chunk_slices[i + 1]]) for i in range(group_size))
-
-
-def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
-    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
-    flat_tensor = torch.cat(list(chunks))
-    result_sizes = tuple(map(torch.Size.numel, shapes))
-    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
-    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
-
-
-def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
-    """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
-    schema_dicts = [{field_name: str(field_value)
-                    for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
-                    for tensor in tensors]
-    return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
-
-
-def is_valid_join_request(request: averaging_pb2.PeerInfo) -> bool:
-    assert len(request.ListFields()) == 3, "this function assumes JoinRequest has three fields, it should be updated"
-    return (isinstance(request.schema_hash, bytes) and
-            isinstance(request.expiration, DHTExpiration) and
-            isinstance(request.endpoint, Endpoint))

+ 0 - 185
hivemind/client/averager.py

@@ -1,185 +0,0 @@
-""" A background process that averages your tensors with peers """
-
-from __future__ import annotations
-
-import ctypes
-from typing import Sequence, Optional, Tuple, Any, Union, Awaitable, Dict
-from concurrent.futures.thread import ThreadPoolExecutor
-import multiprocessing as mp
-import asyncio
-
-import torch
-import uvloop
-import grpc
-
-import hivemind
-from hivemind.dht import get_dht_time, DHTExpiration
-from hivemind.utils import get_logger, Endpoint, Port, MPFuture
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
-from hivemind.client.allreduce import GroupAllReduce, GroupID
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc
-
-logger = get_logger(__file__)
-
-
-class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
-    """
-    **Warning!** Decentralized averager is in active development, some critical functionality is still underway
-
-    Gating function averaging service. A trainer can run this service in background to periodically average his gating
-    function with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
-    group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.
-    Why averaging is valid: see https://github.com/learning-at-home/hivemind/issues/95#issuecomment-688806705
-    On global convergence: see https://github.com/learning-at-home/hivemind/issues/95#issuecomment-717719400
-
-    :param averaged_tensors: a sequence of pytorch tensors that will be averaged in each all-reduce
-    :param dht: a DHT node that will be used to find groups
-    :param start: if True, starts the background process immediately
-    :param timeout: consider allreduce failed if there was no activity for this many **seconds**
-    :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
-            if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
-    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
-    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
-          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
-    :param kwargs: extra parameters forwarded to in grpc.aio.server
-    You can perform averaging using DecentralizedOptimizer (see below) or by manually running each step as such:
-
-    >> TODO add a working example
-    """
-
-    def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
-                 max_size: int = None, timeout: float = 15, listen: bool = True, listen_on: Endpoint = '0.0.0.0:*',
-                 receiver_threads: int = 1, channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
-        super().__init__()
-        self.dht = dht
-        self.server_opts = listen, listen_on, receiver_threads, kwargs
-        self.max_size = max_size if max_size is not None else float('inf')
-        self.timeout = timeout
-        self.channel_options = channel_options
-        self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
-        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
-        self._pending_groups: Dict[GroupID, GroupAllReduce] = {}
-        self._lock_forming_a_group: Optional[asyncio.Lock] = None
-        self.ready = mp.Event()
-
-        self.averaged_tensors = tuple(averaged_tensors)
-        for tensor in self.averaged_tensors:
-            assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
-            tensor.share_memory_()
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-    @property
-    def port(self) -> Optional[Port]:
-        return self._port.value if self._port.value != 0 else None
-
-    def run(self):
-        """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
-        if asyncio.get_event_loop().is_running():
-            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-
-        uvloop.install()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        listen, listen_on, receiver_threads, server_kwargs = self.server_opts
-        pipe_awaiter = ThreadPoolExecutor(receiver_threads)
-        self._lock_forming_a_group = asyncio.Lock()
-
-        async def _run():
-            if listen:
-                grpc.aio.init_grpc_aio()
-                server = grpc.aio.server(**server_kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
-                found_port = server.add_insecure_port(listen_on)
-                assert found_port != 0, f"Failed to listen to {listen_on}"
-                self._port.value = found_port
-                await server.start()
-                self.ready.set()
-            else:
-                raise NotImplementedError("Client-only averaging is not implemented yet.")
-
-            while True:
-                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                asyncio.create_task(getattr(self, method)(*args, **kwargs))
-
-        loop.run_until_complete(_run())
-
-    def run_in_background(self, await_ready=True, timeout=None):
-        """
-        Starts averager in a background process. if await_ready, this method will wait until background dht
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
-
-    def shutdown(self) -> None:
-        """ Shut down the averager process """
-        # TODO notify peers before terminating
-        if self.is_alive():
-            self.terminate()
-        else:
-            logger.warning("DHT shutdown has no effect: the process is not alive")
-
-    def group_allreduce(self, my_endpoint: Endpoint, leader_endpoint: Optional[Endpoint] = None,
-                        return_future=False) -> Union[Sequence[torch.Tensor], Awaitable[Sequence[torch.Tensor]]]:
-        """
-        Set up the averager to look for a group and run all-reduce once, optionally await and return outcome
-
-        :note: this function implemented for debugging and will be removed in future versions
-        :param my_endpoint: public endpoint of this averager
-        :param leader_endpoint: if specified, attempts to join this peer's group
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        """
-        expiration = get_dht_time() + self.timeout
-        assert isinstance(expiration, DHTExpiration)
-
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_group_allreduce', [], dict(my_endpoint=my_endpoint, expiration=expiration,
-                                                     leader_endpoint=leader_endpoint, future=_future)))
-        return future if return_future else future.result()
-
-    async def _group_allreduce(self, *, my_endpoint: Endpoint, expiration: DHTExpiration,
-                               leader_endpoint: Optional[Endpoint], future: MPFuture):
-        group_allreduce = GroupAllReduce(my_endpoint, expiration, self.averaged_tensors)
-        try:
-            if leader_endpoint is None:
-                async with self._lock_forming_a_group:
-                    group_allreduce.start_new_group(max_size=self.max_size)
-                    self._forming_group = self._pending_groups[group_allreduce.group_id] = group_allreduce
-                    await asyncio.wait_for(group_allreduce.assembled_group, expiration - get_dht_time())
-
-                future.set_result(await group_allreduce.run_allreduce())
-            else:
-                async with self._lock_forming_a_group:
-                    stream = await group_allreduce.request_join_group(leader_endpoint)
-                    self._forming_group = self._pending_groups[group_allreduce.group_id] = group_allreduce
-
-                started_allreduce = await group_allreduce.wait_for_allreduce(stream)
-                if started_allreduce:
-                    future.set_result(await group_allreduce.run_allreduce())
-                else:
-                    future.set_exception(ValueError(f"Rejected by {leader_endpoint}"))
-
-        except Exception as e:
-            future.set_exception(e)
-        finally:
-            _ = self._pending_groups.pop(group_allreduce.group_id, None)
-            if group_allreduce is self._forming_group:
-                self._forming_group = None
-
-    async def rpc_group_allreduce(self, request: averaging_pb2.PeerInfo, context: grpc.ServicerContext):
-        """ A peer wants me to be his leader. I will coordinate his actions with the rest of my group. Maybe. """
-        if self._forming_group is None:
-            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)
-            return
-        async for message in self._forming_group.handle_join_request(request):
-            yield message
-
-    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
-        if request.group_id not in self._pending_groups:
-            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
-        return await self._pending_groups[request.group_id].handle_accumulate_request(request)

+ 216 - 0
hivemind/client/averaging/__init__.py

@@ -0,0 +1,216 @@
+""" A background process that averages your tensors with peers """
+
+from __future__ import annotations
+
+import random
+import ctypes
+from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
+from concurrent.futures.thread import ThreadPoolExecutor
+import multiprocessing as mp
+import asyncio
+
+import torch
+import uvloop
+import grpc
+
+import hivemind
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
+from hivemind.client.averaging.matchmaking import Matchmaking
+from hivemind.utils import get_logger, Endpoint, Port, MPFuture, replace_port, GRPC_KEEPALIVE_OPTIONS
+from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+
+# flavour types
+StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
+
+INITIAL_GROUP_NBITS = 3
+logger = get_logger(__file__)
+
+
+class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
+    """
+    **Warning!** Decentralized averager is in active development; some critical functionality is still being implemented
+
+    Parameter averaging service. A trainer can run this service in background to periodically average his parameters
+    with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
+    group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.
+
+    :param averaged_tensors: a sequence of pytorch tensors that will be averaged in each all-reduce
+    :param dht: a DHT node that will be used to find groups
+    :param start: if True, starts the background process immediately
+
+    :param prefix: a shared prefix for all group keys
+    :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
+    :param initial_group_bits: a string of bits ('0' and '1') that define initial group key (bucket index)
+      by default, sample a random bit sequence of length {INITIAL_GROUP_NBITS}
+    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
+      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
+    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+
+    :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
+            if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
+    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
+    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
+          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
+    :param kwargs: extra parameters forwarded to grpc.aio.server
+    You can perform averaging using DecentralizedOptimizer (see below) or by manually running each step as such:
+
+    >> TODO add a working example here
+    """
+    _matchmaking: Matchmaking
+    _pending_group_assembled: asyncio.Event
+
+    def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
+                 prefix: str, target_group_size: int, min_group_size: int = 1, initial_group_bits: Optional[str] = None,
+                 averaging_expiration: float = 15, allreduce_timeout: Optional[float] = None,
+                 compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+                 listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1,
+                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
+        assert '.' not in prefix, "group prefix must be a string without ."
+        if not is_power_of_two(target_group_size):
+            logger.warning("It is recommended to set target_group_size to a power of 2.")
+        if initial_group_bits is None:
+            initial_group_bits = ''.join(random.choices('01', k=INITIAL_GROUP_NBITS))
+            logger.debug(f"Initializing with random {INITIAL_GROUP_NBITS}-bit group index: {initial_group_bits}")
+        assert len(initial_group_bits) >= INITIAL_GROUP_NBITS and all(bit in '01' for bit in initial_group_bits)
+
+        super().__init__()
+        self.dht = dht
+        self.listen_on, self.receiver_threads, self.kwargs = listen_on, receiver_threads, kwargs
+        self.channel_options = channel_options
+        self.averaged_tensors = tuple(averaged_tensors)
+        # TODO use mp.Lock to prevent someone from modifying tensors before we copy them! maybe.
+        for tensor in self.averaged_tensors:
+            assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
+            tensor.share_memory_()
+
+        self.matchmaking_kwargs = dict(prefix=prefix, initial_group_bits=initial_group_bits,
+                                       target_group_size=target_group_size, min_group_size=min_group_size,
+                                       averaging_expiration=averaging_expiration)
+        self.allreduce_timeout, self.compression_type = allreduce_timeout, compression_type
+        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
+
+        self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
+        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
+        self._averager_endpoint: Optional[Endpoint] = None
+        self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    @property
+    def port(self) -> Optional[Port]:
+        return self._port.value if self._port.value != 0 else None
+
+    @property
+    def endpoint(self) -> Endpoint:
+        if self._averager_endpoint is None:
+            self._averager_endpoint = replace_port(self.listen_on, self.port if self.port is not None else '*')
+            logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
+        return self._averager_endpoint
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.endpoint})"
+
+    def run(self):
+        """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
+        if asyncio.get_event_loop().is_running():
+            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+
+        uvloop.install()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        # initialize asyncio synchronization primitives in this event loop
+        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
+
+        async def _run():
+            grpc.aio.init_grpc_aio()
+            server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
+            averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
+            found_port = server.add_insecure_port(self.listen_on)
+            assert found_port != 0, f"Failed to listen to {self.listen_on}"
+            self._port.value = found_port
+            self._matchmaking = Matchmaking(self.endpoint, self.averaged_tensors, self.dht, **self.matchmaking_kwargs)
+            self._pending_group_assembled = asyncio.Event()
+            self._pending_group_assembled.set()
+            await server.start()
+            self.ready.set()
+
+            while True:
+                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                asyncio.create_task(getattr(self, method)(*args, **kwargs))
+
+        loop.run_until_complete(_run())
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts averager in a background process. if await_ready, this method will wait until background dht
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+
+    def shutdown(self) -> None:
+        """ Shut down the averager process """
+        # TODO notify peers before terminating
+        if self.is_alive():
+            self.terminate()
+        else:
+            logger.warning("DHT shutdown has no effect: the process is not alive")
+
+    def step(self, timeout: Optional[float] = None, return_future=False) -> Union[Sequence[torch.Tensor], MPFuture]:
+        """
+        Set up the averager to look for a group and run one round of averaging, then return the averaged tensors
+
+        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_step', [], dict(future=_future, timeout=timeout)))
+        return future if return_future else future.result()
+
+    async def _step(self, *, future: MPFuture, timeout: Optional[float]):
+        group_id = None
+        try:
+            self._pending_group_assembled.clear()
+            allreduce_group = await self._matchmaking.look_for_group(timeout=timeout)
+            if allreduce_group is not None:
+                group_id = allreduce_group.group_id
+                self._running_groups[group_id] = allreduce_group
+                self._pending_group_assembled.set()
+                future.set_result(await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout))
+            else:
+                raise AllreduceException(f"{self} - group_allreduce failed, unable to find a group")
+
+        except Exception as e:
+            future.set_exception(e)
+            raise
+        finally:
+            self._pending_group_assembled.set()
+            if group_id is not None:
+                _ = self._running_groups.pop(group_id, None)
+
+    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
+        """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
+        async for response in self._matchmaking.rpc_join_group(request, context):
+            yield response
+
+    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
+        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
+        if request.group_id not in self._running_groups and not self._pending_group_assembled.is_set():
+            # this handles a special case when leader accepted us to group AND began allreduce right away,
+            # but his response with group_id was delayed and other peers got to us first
+            await self._pending_group_assembled.wait()
+        if request.group_id not in self._running_groups:
+            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        else:
+            return await self._running_groups[request.group_id].rpc_aggregate_part(request, context)
+
+
+def is_power_of_two(n):
+    """ Check whether n is a power of 2 """
+    return (n != 0) and (n & (n - 1) == 0)
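The class docstring above still carries a TODO for a working example. Below is a minimal, hedged sketch of how the new DecentralizedAverager might be driven; it assumes a reachable hivemind.DHT (whose API is outside this diff) and at least two peers running the same code, and the prefix, group size, timeouts and tensor shapes are illustrative values rather than part of the commit.

# hedged usage sketch: run the same script on two or more peers that share one DHT
import torch
import hivemind
from hivemind.client.averaging import DecentralizedAverager

dht = hivemind.DHT(start=True)  # in practice, pass initial_peers so every peer joins the same DHT

local_tensors = [torch.randn(16), torch.randn(3, 3)]  # same shapes and dtypes on every peer
averager = DecentralizedAverager(local_tensors, dht, start=True,
                                 prefix='demo_averaging', target_group_size=4,
                                 averaging_expiration=5)

averaged = averager.step(timeout=60)  # blocks until one round of allreduce finishes (or raises)
print([tensor.shape for tensor in averaged])

averager.shutdown()
dht.shutdown()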

+ 184 - 0
hivemind/client/averaging/allreduce.py

@@ -0,0 +1,184 @@
+import asyncio
+from typing import Sequence, Set, Dict, Tuple
+
+import grpc
+import torch
+
+from hivemind.utils import Endpoint, get_logger, serialize_torch_tensor, deserialize_torch_tensor, ChannelCache
+from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
+
+# flavour types
+GroupID = bytes
+logger = get_logger(__name__)
+
+
+class AllReduceProtocol:
+    """
+    An internal class that runs butterfly AllReduce in a predefined group of averagers
+
+    :param tensors: local tensors that should be averaged with groupmates
+    :param endpoint: your endpoint, must be included in ordered_group_endpoints
+    :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    """
+    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
+                 ordered_group_endpoints: Sequence[Endpoint]):
+        assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, self.group_size)))
+        self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
+
+        self.accumulator = self.local_tensor_parts[self.endpoint].clone()  # sum inputs from peers to this tensor
+        self.accumulated_from: Set[Endpoint] = {self.endpoint}  # peers that we have accumulated our part from
+        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
+        self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
+        self.averaged_tensors: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
+
+    def __await__(self):
+        return self.averaged_tensors.__await__()
+
+    @property
+    def group_size(self):
+        return len(self.ordered_group_endpoints)
+
+    async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor) -> torch.Tensor:
+        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
+        assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
+        assert not self.averaged_tensors.done(), f"already finished allreduce: {self.averaged_tensors}"
+        assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
+        assert source not in self.accumulated_from, "duplicate source, already received that part"
+        logger.debug(f"{self} - accumulating tensor part from {source}")
+
+        self.accumulator.add_(remote_part)
+        self.accumulated_from.add(source)
+
+        assert len(self.accumulated_from) <= self.group_size
+        if len(self.accumulated_from) == len(self.local_tensor_parts):
+            average_result = self.accumulator.div_(len(self.accumulated_from))
+            self.register_averaged_part(self.endpoint, average_result)
+            self.averaged_part.set_result(average_result)
+
+        return await self.averaged_part
+
+    def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
+        assert not self.averaged_tensors.done(), f"already finished allreduce: {self.averaged_tensors}"
+        assert source in self.local_tensor_parts, "the provider of averaged part is not from my group"
+        assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
+        assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
+        assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
+        logger.debug(f"{self} - receiving averaged tensor part from {source}")
+        self.averaged_tensor_parts[source] = averaged_part
+        if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
+            ordered_averaged_parts = [self.averaged_tensor_parts[endpoint] for endpoint in self.ordered_group_endpoints]
+            self.averaged_tensors.set_result(restore_from_parts(ordered_averaged_parts, self.tensor_shapes))
+
+    def cancel(self) -> bool:
+        if not self.averaged_tensors.done():
+            logger.debug(f"{self} - cancelled")
+            self.averaged_tensors.cancel()
+            if not self.averaged_part.done():
+                self.averaged_part.cancel()
+            return True
+        else:
+            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.averaged_tensors}")
+            return False
+
+    def set_exception(self, exception: Exception) -> bool:
+        if not self.averaged_tensors.done():
+            logger.debug(f"{self} - {exception}")
+            self.averaged_tensors.set_exception(exception)
+            if not self.averaged_part.done():
+                self.averaged_part.cancel()
+            return True
+        else:
+            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.averaged_tensors}")
+            return False
+
+
+class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragingServicer):
+    """
+    A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
+    """
+    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
+                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType):
+        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint,
+                         ordered_group_endpoints=ordered_group_endpoints)
+        self.compression_type = compression_type
+
+    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
+        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+
+    async def _average_one_part(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
+        """ Send one part of local tensors to one groupmate and collect the average for this part """
+        serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
+        response = await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(
+            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
+                                        endpoint=self.endpoint, tensor_part=serialized_tensor_part))
+        if response.code == averaging_pb2.AVERAGED_PART:
+            averaged_part = deserialize_torch_tensor(response.tensor_part)
+            self.register_averaged_part(peer_endpoint, averaged_part)
+            return averaged_part
+        else:
+            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(response.code)}"
+                                     f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
+                                     f" allreduce failed")
+
+    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
+        await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
+            group_id=self.group_id, endpoint=self.endpoint, code=code))
+
+    async def run(self) -> Sequence[torch.Tensor]:
+        """ send allreduce requests to all peers and collect results, return the averaged tensor """
+        try:
+            await asyncio.gather(self, *(self._average_one_part(peer, part)
+                                         for peer, part in self.local_tensor_parts.items() if peer != self.endpoint))
+            return await self
+        except Exception as e:
+            code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
+            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
+            self.set_exception(e)
+            for peer_endpoint in self.ordered_group_endpoints:
+                asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
+            raise
+
+    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
+        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
+        if request.group_id != self.group_id:
+            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+
+        if request.code == averaging_pb2.PART_FOR_AVERAGING:
+            try:
+                tensor_part = deserialize_torch_tensor(request.tensor_part)
+                averaged_part = await self.accumulate_part(request.endpoint, tensor_part)
+                serialized = serialize_torch_tensor(averaged_part, request.tensor_part.compression, allow_inplace=False)
+                return averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized)
+            except Exception as e:
+                self.set_exception(e)
+                return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+        else:
+            error_code = averaging_pb2.MessageCode.Name(request.code)
+            logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
+            self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+
+
+def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor, ...]:
+    """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
+    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
+    chunk_slices = torch.linspace(start=0, end=len(flat_tensor), steps=group_size + 1, dtype=torch.int64)
+    chunk_slices[-1] = len(flat_tensor)
+    return tuple(flat_tensor[chunk_slices[i]: chunk_slices[i + 1]] for i in range(group_size))
+
+
+def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
+    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
+    flat_tensor = torch.cat(tuple(chunks))
+    result_sizes = tuple(map(torch.Size.numel, shapes))
+    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
+    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
+
+
+class AllreduceException(Exception):
+    """ A special exception that is raised when allreduce can't continue normally (e.g. disbanded/bad request/etc) """

+ 394 - 0
hivemind/client/averaging/matchmaking.py

@@ -0,0 +1,394 @@
+""" A background process that averages your tensors with peers """
+
+from __future__ import annotations
+
+import contextlib
+import random
+from dataclasses import asdict
+from math import isfinite
+from typing import Sequence, Optional, AsyncIterator, Set
+import asyncio
+
+import torch
+import grpc
+
+import hivemind
+from hivemind.client.averaging.allreduce import AllReduceRunner, GroupID
+from hivemind.dht import DHTID, DHTExpiration, get_dht_time, GroupKey
+from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, TimedStorage
+from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+from hivemind.utils.grpc import ChannelCache
+
+
+logger = get_logger(__file__)
+
+
+class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
+    f"""
+    An internal class that is used to form groups of averages for running allreduce
+    See DecentralizedAverager docstring for the detailed description of all parameters
+    """
+
+    def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
+                 prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
+                 averaging_expiration: float = 15, compression_type: runtime_pb2.CompressionType = runtime_pb2.NONE):
+        assert '.' not in prefix, "group prefix must be a string without ."
+
+        super().__init__()
+        self.dht, self.endpoint, self.averaged_tensors = dht, endpoint, tuple(averaged_tensors)
+        self.prefix, self.group_bits = prefix, initial_group_bits
+        self.target_group_size, self.min_group_size = target_group_size, min_group_size
+        self.averaging_expiration, self.compression_type = averaging_expiration, compression_type
+
+        self.schema_hash = compute_schema_hash(self.averaged_tensors)
+
+        self.lock_looking_for_group = asyncio.Lock()
+        self.lock_request_join_group = asyncio.Lock()
+        self.cond_notify_followers = asyncio.Condition()
+        self.assembled_group = asyncio.Future()
+
+        self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
+        self.current_followers: Set[Endpoint] = set()  # iff i am a leader, this contains my followers excluding myself
+        self.potential_leaders = PotentialLeaders(self.endpoint, self.dht, self.averaging_expiration)
+
+    @property
+    def is_looking_for_group(self):
+        return self.lock_looking_for_group.locked()
+
+    @property
+    def current_group_key(self) -> GroupKey:
+        return f"{self.prefix}.0b{self.group_bits}"
+
+    def __repr__(self):
+        lfg_status = "looking for group," if self.is_looking_for_group else "not looking for group,"
+        if self.is_looking_for_group:
+            if self.current_leader:
+                lfg_status += f" following {self.current_leader},"
+            if len(self.current_followers):
+                lfg_status += f" leading {len(self.current_followers)} followers,"
+        schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
+        return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
+               f" current key = {self.current_group_key})"
+
+    async def look_for_group(self, *, timeout: Optional[float] = None) -> AllReduceRunner:
+        """
+        :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
+        Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
+        """
+        if self.is_looking_for_group:
+            logger.info("Another look_for_group is already in progress. The current run will be scheduled after"
+                        " the existing group is either assembled or disbanded.")
+        async with self.lock_looking_for_group:
+            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
+            try:
+                return await asyncio.wait_for(self.assembled_group, timeout=timeout)
+            except Exception as e:
+                if len(self.current_followers) > 0:
+                    async with self.lock_request_join_group:
+                        await self.leader_disband_group()
+                self.assembled_group.set_exception(e)
+                raise
+
+            finally:
+                if not request_leaders_task.done():
+                    request_leaders_task.cancel()
+                if self.assembled_group.done():
+                    self.assembled_group = asyncio.Future()
+
+    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
+        """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
+        end_time = get_dht_time() + timeout if timeout is not None else float('inf')
+        async with self.potential_leaders.begin_search(self.current_group_key, timeout):
+            # TODO update group_bits on success! reduce number of bits on not enough peers.
+            # TODO after allreduce finishes, we may need to ask leader to notify lower keys about this
+            # (so as to fix possible network partitioning if some peers operate on a much smaller nbits)
+            while True:
+                try:
+                    time_to_expiration = self.potential_leaders.declared_expiration_time - get_dht_time()
+                    next_best_leader = await asyncio.wait_for(
+                        self.potential_leaders.pop_next_leader(),
+                        timeout=time_to_expiration if isfinite(time_to_expiration) else None)
+
+                    request_expiration_time = min(self.potential_leaders.declared_expiration_time,
+                                                  end_time, get_dht_time() + self.averaging_expiration)
+                    group = await self.request_join_group(next_best_leader, request_expiration_time)
+                    if group is not None:
+                        return group
+
+                except asyncio.TimeoutError:
+                    async with self.lock_request_join_group:
+                        if len(self.current_followers) >= self.min_group_size:
+                            # the time is up, we have a *good enough* group. run allreduce as is.
+                            return await self.leader_assemble_group()
+                        else:
+                            await self.leader_disband_group()
+                            # TODO maybe adjust grid size
+                            continue
+
+    async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[AllReduceRunner]:
+        """
+        :param leader: request this peer to be your leader for allreduce
+        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
+        :returns: if leader leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
+        :note: this function does not guarantee that your group leader is the same as :leader: parameter
+          The originally specified leader can disband group and redirect us to a different leader
+        """
+        assert self.is_looking_for_group and self.current_leader is None
+        call: Optional[grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]] = None
+        try:
+            async with self.lock_request_join_group:
+                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+                call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
+                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time))
+
+                message = await call.read()
+                if message.code != averaging_pb2.ACCEPTED:
+                    code = averaging_pb2.MessageCode.Name(message.code)
+                    logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                    return None
+
+                # else: we were accepted
+                logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                self.current_leader = leader
+                if len(self.current_followers) > 0:
+                    await self.leader_disband_group()
+
+            async with self.potential_leaders.pause_search():
+                message = await call.read()
+
+            if message.code == averaging_pb2.BEGIN_ALLREDUCE:
+                async with self.lock_request_join_group:
+                    return await self.follower_assemble_group(leader, message.group_id, message.ordered_group_endpoints)
+            elif message.code == averaging_pb2.GROUP_DISBANDED and bool(message.suggested_leader):
+                logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
+                return await self.request_join_group(message.suggested_leader, expiration_time)
+
+            else:
+                logger.debug(f"{self} - leader sent {averaging_pb2.MessageCode.Name(message.code)}, leaving group")
+                return None
+        finally:
+            self.current_leader = None
+            if call is not None:
+                call.cancel()
+
+    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
+        """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
+        try:
+            reason_to_reject = self._check_reasons_to_reject(request)
+            if reason_to_reject is not None:
+                yield reason_to_reject
+                return
+
+            current_group = self.assembled_group  # copy current assembled_group to avoid overwriting
+            async with self.lock_request_join_group:
+                self.current_followers.add(request.endpoint)
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
+
+                if len(self.current_followers) + 1 >= self.target_group_size:
+                    # outcome 1: we have assembled a full group and are ready for allreduce
+                    await self.leader_assemble_group()
+
+            if not current_group.done():
+                try:
+                    async with self.cond_notify_followers:
+                        # wait for the group to be assembled or disbanded
+                        timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
+                        await asyncio.wait_for(self.cond_notify_followers.wait(), timeout=timeout)
+                except asyncio.TimeoutError:
+                    async with self.lock_request_join_group:
+                        # outcome 2: the time is up, run allreduce with what we have or disband
+                        if len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
+                            await self.leader_assemble_group()
+                        else:
+                            await self.leader_disband_group()
+
+            if self.current_leader is not None:
+                # outcome 3: found by a leader with higher priority, send our followers to him
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
+                                                      suggested_leader=self.current_leader)
+                return
+
+            if request.endpoint not in self.current_followers:
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
+                return
+
+            # finally, run allreduce
+            allreduce_group = current_group.result()
+            yield averaging_pb2.MessageFromLeader(
+                code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
+                ordered_group_endpoints=allreduce_group.ordered_group_endpoints)
+
+        except Exception as e:
+            logger.exception(e)
+            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
+
+        finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
+            self.current_followers.discard(request.endpoint)
+
+    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> averaging_pb2.MessageFromLeader:
+        """ :returns: if accepted, return None, otherwise return a reason for rejection """
+        if not self.is_looking_for_group:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)
+
+        if request.ListFields() == 3 and not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
+                or not isinstance(request.expiration, DHTExpiration) or not isfinite(request.expiration) \
+                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
+
+        elif request.schema_hash != self.schema_hash:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+        elif self.potential_leaders.declared_group_key is None:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
+        elif self.potential_leaders.declared_expiration_time > (request.expiration or float('inf')):
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
+        elif self.current_leader is not None:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER,
+                                                   suggested_leader=self.current_leader)
+        elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
+        elif len(self.current_followers) + 1 >= self.target_group_size:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
+        else:
+            return None
+
+    async def leader_assemble_group(self) -> AllReduceRunner:
+        """ Form up all current followers into a group and prepare to _run_allreduce """
+        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        group_id = DHTID.generate().to_bytes()
+        ordered_group_endpoints = list(self.current_followers)
+        ordered_group_endpoints.append(self.endpoint)
+        random.shuffle(ordered_group_endpoints)
+        logger.debug(f"{self.endpoint} - leader started allreduce with {len(ordered_group_endpoints)} followers.")
+        allreduce_group = AllReduceRunner(
+            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        self.assembled_group.set_result(allreduce_group)
+        async with self.cond_notify_followers:
+            self.cond_notify_followers.notify_all()
+        return allreduce_group
+
+    async def follower_assemble_group(self, leader: Endpoint, group_id: GroupID,
+                                      ordered_group_endpoints: Sequence[Endpoint]) -> AllReduceRunner:
+        """ Prepare to run allreduce using a list of peers provided by our leader """
+        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
+        assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
+        assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
+        allreduce_group = AllReduceRunner(
+            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        self.assembled_group.set_result(allreduce_group)
+        async with self.cond_notify_followers:
+            self.cond_notify_followers.notify_all()
+        return allreduce_group
+
+    async def leader_disband_group(self):
+        """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
+        assert self.lock_request_join_group.locked()
+        self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out
+        async with self.cond_notify_followers:
+            self.cond_notify_followers.notify_all()
+
+
+class PotentialLeaders:
+    """ An utility class that searches for averagers that could become our leaders """
+    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration):
+        self.endpoint, self.dht, self.averaging_expiration = endpoint, dht, averaging_expiration
+        self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
+        self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
+        self.max_assured_time = float('-inf')
+        self.declared_expiration_time = float('inf')
+        self.declared_group_key: Optional[GroupKey] = None
+        self.search_end_time = float('inf')
+
+    @contextlib.asynccontextmanager
+    async def begin_search(self, group_key: GroupKey, timeout: Optional[float]):
+        assert not self.running.is_set(), "already running"
+        self.running.set()
+        self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
+        update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
+        declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
+        try:
+            yield self
+        finally:
+            update_queue_task.cancel()
+            declare_averager_task.cancel()
+            self.running.clear()
+            self.update_triggered.clear()
+            self.update_finished.clear()
+
+    @contextlib.asynccontextmanager
+    async def pause_search(self):
+        was_running = self.running.is_set()
+        try:
+            self.running.clear()
+            yield
+        finally:
+            if was_running:
+                self.running.set()
+            else:
+                self.running.clear()
+
+    async def pop_next_leader(self) -> Endpoint:
+        """ Remove and return the next most suitable leader or throw an exception if reached timeout """
+        assert self.running, "Not running search at the moment"
+        maybe_next_leader, entry = self.leader_queue.top()
+
+        next_entry_time = entry.expiration_time if maybe_next_leader is not None else get_dht_time()
+        if self.max_assured_time < next_entry_time < self.search_end_time:
+            self.update_triggered.set()
+
+        if maybe_next_leader is None:
+            await self.update_finished.wait()
+            return await self.pop_next_leader()
+
+        del self.leader_queue[maybe_next_leader]
+        return maybe_next_leader
+
+    async def _update_queue_periodically(self, group_key: GroupKey):
+        DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+        while get_dht_time() < self.search_end_time:
+            new_peers = await self.dht.get_averagers(group_key, only_active=True, return_future=True)
+            self.max_assured_time = max(self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY)
+
+            for peer, peer_expiration_time in new_peers:
+                if peer == self.endpoint:
+                    continue
+                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+
+            if len(self.leader_queue) > 0:
+                self.update_finished.set()
+
+            await asyncio.wait(
+                {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
+                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
+            self.update_triggered.clear()
+
+    async def _declare_averager_periodically(self, group_key: GroupKey):
+        try:
+            while True:
+                new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
+                stored_ok = await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
+                                                            looking_for_group=True, return_future=True)
+                if stored_ok:
+                    await asyncio.sleep(self.declared_expiration_time - get_dht_time())
+                else:
+                    logger.warning(f"Failed to subscribe to group {group_key} : store rejected by DHT peers")
+        finally:
+            if self.declared_group_key is not None:
+                previous_declared_key, previous_expiration_time = self.declared_group_key, self.declared_expiration_time
+                self.declared_group_key, self.declared_expiration_time = None, float('inf')
+                self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
+                await self.dht.declare_averager(previous_declared_key, self.endpoint, previous_expiration_time,
+                                                looking_for_group=False, return_future=True)
+
+
+def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
+    """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
+    schema_dicts = [{field_name: str(field_value)
+                    for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
+                    for tensor in tensors]
+    return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()

+ 75 - 0
hivemind/dht/__init__.py

@@ -22,6 +22,7 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
 
 import uvloop
+from numpy import nextafter
 
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
@@ -37,16 +38,25 @@ FLAT_EXPERT = -1     # grid prefix reserved for storing 1d expert uids. Used to
 UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
 PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
 #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
+GroupKey = str
+GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]+$')  # e.g. bert_exp4_averaging.0b01001101
 
 
 def is_valid_uid(maybe_uid: str) -> bool:
+    """ An uid must contain a string expert type, followed by one or more .-separated numeric indices """
     return bool(UID_PATTERN.fullmatch(maybe_uid))
 
 
 def is_valid_prefix(maybe_prefix: str) -> bool:
+    """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
     return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
 
 
+def is_valid_group(maybe_group: str) -> bool:
+    """ A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata"""
+    return bool(GROUP_PATTERN.fullmatch(maybe_group))
+
+
 def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
     """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
@@ -118,6 +128,7 @@ class DHT(mp.Process):
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
                  receiver_threads: int = 1, negative_caching: bool = True, expiration: float = 300, **kwargs):
         super().__init__()
+        assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.receiver_threads, self.max_workers, self.parallel_rpc = receiver_threads, max_workers, parallel_rpc
         self.expiration, self.negative_caching = expiration, negative_caching
@@ -457,3 +468,67 @@ class DHT(mp.Process):
         if future is not None:
             future.set_result(best_experts_batch)
         return best_experts_batch
+
+    def declare_averager(self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float, *,
+                         looking_for_group: bool = True, return_future: bool = False) -> Union[bool, MPFuture]:
+        """
+        Add (or remove) the averager to a given allreduce bucket
+
+        :param group_key: allreduce group key, e.g. my_averager.0b011011101
+        :param endpoint: averager public endpoint for incoming requests
+        :param expiration_time: intent to run allreduce before this timestamp
+        :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
+          If False, this will instead mark that the averager as no longer looking for group, (e.g. it already finished)
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :return: True if declared, False if declaration was rejected by DHT peers
+        :note: when leaving (i.e. is_active=False), please specify the same expiration_time as when entering the group
+        :note: setting is_active=False does *not* guarantee that others will immediately stop to query you.
+        """
+        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_declare_averager', [],
+                        dict(group_key=group_key, endpoint=endpoint, expiration_time=expiration_time,
+                             looking_for_group=looking_for_group, future=_future)))
+        return future if return_future else future.result()
+
+    async def _declare_averager(self, node: DHTNode, *, group_key: str, endpoint: Endpoint,
+                                expiration_time: DHTExpiration, looking_for_group: bool, future: MPFuture):
+        try:
+            expiration_time = expiration_time if looking_for_group else nextafter(expiration_time, float('inf'))
+            # ^-- when declaring averager inactive, we increment expiration time to overwrite the pre-existing entry
+            store_ok = await node.store(
+                key=group_key, subkey=endpoint, value=looking_for_group, expiration_time=expiration_time)
+            future.set_result(store_ok)
+        except Exception as e:
+            future.set_exception(e)
+
+    def get_averagers(self, group_key: GroupKey, *, only_active: bool = True, return_future: bool = False
+                      ) -> Union[List[Tuple[Endpoint, DHTExpiration]], MPFuture]:
+        """
+        Find and return averagers in a specified all-reduce bucket
+
+        :param group_key: finds averagers that have the this group key, e.g. my_averager.0b011011101
+        :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
+            if False, return all averagers under a given group_key regardless of value
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :return: endpoints and expirations of every matching averager
+        """
+        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_averagers', [], dict(group_key=group_key, only_active=only_active, future=_future)))
+        return future if return_future else future.result()
+
+    async def _get_averagers(self, node: DHTNode, *, group_key: str, only_active: bool, future: MPFuture):
+        try:
+            result = await node.get(group_key, latest=True)
+            if result is None:
+                logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
+                future.set_result([])
+                return
+            assert isinstance(result.value, dict), f"expected {group_key} to be a Dict[Endpoint, is_active], " \
+                                                   f"but got {result.value} of type {type(result.value)}."
+            averagers = [(endpoint, entry.expiration_time) for endpoint, entry in result.value.items()
+                         if not only_active or entry.value is True]
+            future.set_result(averagers)
+        except Exception as e:
+            future.set_exception(e)

+ 23 - 20
hivemind/proto/averaging.proto

@@ -4,34 +4,37 @@ import "runtime.proto";
 
 // Runs alongside each trainer to perform gating function averaging every now and then. Read more: client/averaging.py
 service DecentralizedAveraging {
-  rpc rpc_group_allreduce(PeerInfo) returns (stream MessageFromLeader);  // assemble a group and run all-reduce
+  rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
   rpc rpc_aggregate_part(AveragingData) returns (AveragingData);  // send my local shard => get aggregated shard
 }
 
-message PeerInfo {
+enum MessageCode {
+  NO_CODE = 0;               // Default value that should not be used explicitly
+  REQUEST_JOIN = 1;          // "Dear maybe leader, will you have me in your group as a follower?"
+  ACCEPTED = 2;              // "I accept you in my group, you now commit to responding to me"
+  BEGIN_ALLREDUCE = 3;       // "We can begin allreduce now. These are your peers."
+  PART_FOR_AVERAGING = 4;    // "I am running allreduce with you, here's a part of my tensor that you should aggregate"
+  AVERAGED_PART = 5;         // "I aggregated your part with others and here's the average for that part"
+  NOT_DECLARED = 6;          // "I have not declared my group id yet, how the heck did you even find me? Go away."
+  NOT_A_LEADER = 7;          // "I am not a group a leader. Go ask my leader instead."
+  BAD_EXPIRATION_TIME = 8;   // "I will not accept you. I cannot guarantee that we begin before you expire."
+  BAD_SCHEMA_HASH = 9;       // "I will not accept you. I am not averaging the samy type of tensors as you."
+  BAD_GROUP_ID = 10;         // "I will not accept your request, your group id does not match with any groups i'm in."
+  DUPLICATE_ENDPOINT = 11;   // "I will not accept you, i already have exactly the same endpoint in my current group."
+  GROUP_IS_FULL = 12;        // "I will not accept you, my group already contains too many peers."
+  NOT_LOOKING_FOR_GROUP = 13;// "I'm not available at the moment. Please, get lost."
+  PROTOCOL_VIOLATION = 14;   // "You did something so unspeakable that i don't have a special code for that."
+  INTERNAL_ERROR = 15;       // "I messed up, we will have to stop allreduce because of that."
+  CANCELLED = 16;            // "[from peer during allreduce] I no longer want to participate in AllReduce."
+  GROUP_DISBANDED = 17;      // "[from leader] The group is closed. Go find another group."
+}
+
+message JoinRequest {
   string endpoint = 1;          // A follower accepts incoming allreduce requests at this address
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
 }
 
-enum MessageCode {
-  // response to join request
-  ACCEPTED = 0;              // "I accept you in my group, you will not commit to responding to me."
-  NOT_A_LEADER = 1;          // "I am not a group a leader. Go ask my leader instead."
-  ALREADY_RUNNING = 2;       // "My group has already began merging. Here's the group leader."
-  NOT_LOOKING_FOR_GROUP = 3; // "I'm not available at the moment. Please, get lost."
-  BAD_EXPIRATION_TIME = 4;   // "I will not accept you. I cannot guarantee that we begin before you expire."
-  BAD_SCHEMA_HASH = 5;       // "I will not accept you. I am not averaging the samy type of tensors as you."
-  DUPLICATE_ENDPOINT = 6;    // "I will not accept you, i already have exactly the same endpoint in my current group"
-  GROUP_IS_FULL = 7;         // "I will not accept you, my group already contains too many peers"
-  BEGIN_ALLREDUCE = 8;       // "We can begin allreduce now. These are your peers."
-  GROUP_DISBANDED = 9;       // "The group is closed. Go find another group."
-  UNKNOWN_GROUP_ID = 10;     // "Your request uses a group id that doesn't match with any group i know"
-  PROTOCOL_VIOLATION = 11;   // "One of peers did something in violation of the allreduce protocol"
-  INTERNAL_ERROR = 12;       // "We encountered an unexpected error on our side"
-  CANCELLED = 13;            // "A peer cancelled allreduce while averaging"
-}
-
 message MessageFromLeader {
   MessageCode code = 1;
   bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed

+ 1 - 0
hivemind/utils/timed_storage.py

@@ -8,6 +8,7 @@ from typing import TypeVar, NamedTuple, Generic, Optional, Dict, List, Iterator,
 KeyType = TypeVar('KeyType')
 ValueType = TypeVar('ValueType')
 get_dht_time = time.time  # a global (weakly synchronized) time
+MAX_DHT_TIME_DISCREPANCY_SECONDS = 3  # max allowed difference between get_dht_time for two DHT nodes. Enforced when joining DHT.(TODO)
 DHTExpiration = float
 ROOT = 0
 

+ 57 - 55
tests/test_averaging.py

@@ -1,46 +1,62 @@
 import asyncio
 import random
 import time
-from itertools import product
 
 import torch
 import pytest
 import hivemind
-from hivemind.client.allreduce import GroupAllReduce, split_into_parts, restore_from_parts
-from hivemind.utils import LOCALHOST
+from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
+from hivemind.utils import Endpoint
 
 
 @pytest.mark.forked
-@pytest.mark.asyncio
-async def test_allreduce_direct():
-    # WARNING! this test uses an early interface that will change by the time DecentralizedAverager is finished
+def test_getset_averagers():
+    dht = hivemind.DHT(start=True)
+
+    t = hivemind.get_dht_time()
+    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 60)
+    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', expiration_time=t + 61)
+
+    q1 = dht.get_averagers('bucket.0b10110', only_active=True)
+
+    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 66)
+    q2 = dht.get_averagers('bucket.0b10110', only_active=True)
+
+    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', looking_for_group=False,
+                         expiration_time=t + 61)
+    q3 = dht.get_averagers('bucket.0b10110', only_active=True)
+    q4 = dht.get_averagers('bucket.0b10110', only_active=False)
+
+    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
+    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
+    assert len(q3) == 1 and ('localhvost', t + 66) in q3
+    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q2
+
 
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_allreduce_once():
     dht = hivemind.DHT(start=True)
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
+    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
 
-    reference = [(tensors1[i] + tensors2[i] + tensors3[i]) / 3 for i in range(len(tensors1))]
-
-    averager1 = hivemind.DecentralizedAverager(tensors1, dht=dht, start=True, max_size=3, timeout=5)
-    averager2 = hivemind.DecentralizedAverager(tensors2, dht=dht, start=True, max_size=3, timeout=5)
-    averager3 = hivemind.DecentralizedAverager(tensors3, dht=dht, start=True, max_size=3, timeout=5)
-
-    future1 = averager1.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager1.port}",
-                                        leader_endpoint=None, return_future=True)
-    time.sleep(0.1)
+    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
 
-    future2 = averager2.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager2.port}",
-                                        leader_endpoint=f"{LOCALHOST}:{averager1.port}",
-                                        return_future=True)
+    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
+                                                prefix='mygroup', initial_group_bits='0110', listen_on='127.0.0.1:*',
+                                                start=True)
+                 for tensors in [tensors1, tensors2, tensors3, tensors4]]
 
-    future3 = averager3.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager3.port}",
-                                        leader_endpoint=f"{LOCALHOST}:{averager1.port}",
-                                        return_future=True)
+    futures = []
+    for averager in averagers:
+        futures.append(averager.step(return_future=True))  # TODO revert to hard version
+        time.sleep(0.5)
 
-    for future in future1, future2, future3:
-        for ref, our in zip(reference, await future):
+    for future in futures:
+        for ref, our in zip(reference, future.result()):
             assert torch.allclose(ref, our)
 
 
@@ -49,50 +65,36 @@ async def test_allreduce_direct():
 async def test_allreduce_protocol():
     """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
     peers = "alice", "bob", "carol"
-    expiration_offsets = 4, 0, 1
 
     tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
                        for i, peer in enumerate(peers)}
 
-    alice, bob, carol = allreduce_protocols = [
-        GroupAllReduce(endpoint=peer, expiration=hivemind.get_dht_time() + offset, tensors=tensors_by_peer[peer])
-        for peer, offset in zip(peers, expiration_offsets)]
-
-    bob.start_new_group()
-    bob.add_peer_to_group(alice.info.endpoint)
-    alice.join_group(bob, bob.group_id)
-    bob.add_peer_to_group(carol.info.endpoint)
-    carol.join_group(carol, bob.group_id)
-
-    bob.leader_begin_allreduce()
-    ordered_group_endpoints = await bob.assembled_group
-    assert len(ordered_group_endpoints) == len(peers)
-
-    carol.follower_begin_allreduce(ordered_group_endpoints)
-    alice.follower_begin_allreduce(ordered_group_endpoints)
-
-    chunks_by_peer = {protocol.info.endpoint: {
-        peer: part for peer, part in zip(peers, split_into_parts(protocol.local_tensors, len(ordered_group_endpoints)))
-    } for protocol in allreduce_protocols}
+    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
+    allreduce_protocols = [AllReduceProtocol(
+        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer], ordered_group_endpoints=peers)
+        for peer in peers]
 
-    all_pairs = list(product(allreduce_protocols, peers))
-    random.shuffle(all_pairs)
-    await asyncio.gather(*(
-        peer_allreduce.accumulate(source_peer, chunks_by_peer[source_peer][peer_allreduce.info.endpoint])
-        for peer_allreduce, source_peer in all_pairs))
+    async def _accumulate(sender: Endpoint, recipient: Endpoint):
+        sender_allreduce = allreduce_protocols[peers.index(sender)]
+        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
+        averaged_part = await recipient_allreduce.accumulate_part(
+            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
+        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)
 
-    averaged_parts = await asyncio.gather(*(protocol.averaged_part for protocol in allreduce_protocols))
-    tensor_shapes = [tensor.shape for tensor in alice.local_tensors]
-    averaged_tensors = restore_from_parts(averaged_parts, tensor_shapes)
+    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
+                        if sender != recipient})
 
     reference_tensors = [
         sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
         for i in range(len(tensors_by_peer[peers[0]]))
     ]
 
-    assert len(averaged_tensors) == len(reference_tensors)
-    assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
-               for our, ref in zip(averaged_tensors, reference_tensors))
+    for peer, allreduce in zip(peers, allreduce_protocols):
+        assert allreduce.averaged_tensors.done()
+        averaged_tensors = await allreduce
+        assert len(averaged_tensors) == len(reference_tensors)
+        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
+                   for our, ref in zip(averaged_tensors, reference_tensors))
 
 
 @pytest.mark.forked