File overview

Implement averaging parameters over DHT (2/3)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
revision
eb93789ac6

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.19'
+__version__ = '0.8.20'

+ 1 - 1
hivemind/client/__init__.py

@@ -1,3 +1,3 @@
 from hivemind.client.expert import RemoteExpert
 from hivemind.client.moe import RemoteMixtureOfExperts
-from hivemind.client.averager import DecentralizedAverager
+from hivemind.client.averaging import DecentralizedAverager

+ 0 - 356
hivemind/client/allreduce.py

@@ -1,356 +0,0 @@
-""" This file contains a state machine that defines allreduce protocol used in DecentralizedAverager """
-from __future__ import annotations
-import asyncio
-import random
-from dataclasses import asdict
-from typing import Set, Optional, Sequence, Tuple, Dict, AsyncIterator
-from enum import Enum, auto
-
-import grpc
-import torch
-
-from hivemind.dht import DHTID, DHTExpiration
-from hivemind.utils import Endpoint, get_logger, MSGPackSerializer
-from hivemind.utils import TensorDescriptor, deserialize_torch_tensor, serialize_torch_tensor, ChannelCache
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-
-logger = get_logger(__name__)
-
-# flavour types
-GroupID = bytes
-
-
-class ProtocolState(Enum):
-    LOOKING_FOR_GROUP = auto()   # i want to run averaging, but haven't found any peers yet
-    LEADER_WAITING_FOR_PEERS = auto()     # i am a leader, waiting for more peers to join
-    FOLLOWER_WAITING_FOR_LEADER = auto()  # i am a follower, my leader is assembling the group
-    RUNNING_ALLREDUCE = auto()   # we are currently exchanging tensors in a group
-    FINISHED_NORMALLY = auto()   # we ran allreduce and finished without errors
-    GROUP_DISBANDED = auto()     # leader disbanded the group before we began allreduce
-    ERROR = auto()               # someone (maybe i) messed up and we can't recover
-    CANCELLED = auto()           # i have unilaterally cancelled GroupAllreduce
-
-
-class GroupAllReduce:
-    """
-    An internal class that keeps track of one group allreduce for DecentralizedAverager.
-    GroupAllReduce is meant to be modified with methods, no direct variable assignments is allowed outside of debugging.
-
-    :param endpoint: my endpoint, as seen by the group leader
-    :param expiration: the time after which the group should begin allreduce or be disbanded
-    :param tensors: a sequence of torch tensors that i intend to average with peers
-    """
-    compression_type = runtime_pb2.NONE
-
-    def __init__(self, endpoint: Endpoint, expiration: DHTExpiration, tensors: Sequence[torch.Tensor]):
-        assert all(tensor.dtype == torch.float32 and tensor.device == torch.device('cpu') for tensor in tensors)
-        self.local_tensors = tensors
-        self.state = ProtocolState.LOOKING_FOR_GROUP
-        self.info = averaging_pb2.PeerInfo(endpoint=endpoint, expiration=expiration,
-                                           schema_hash=compute_schema_hash(tensors))
-
-        self.leader_endpoint: Optional[Endpoint] = None
-        self.group_id: Optional[GroupID] = None  # a unique identifier of this one group all-reduce
-        self.max_size = float('inf')  # maximum group size, only enforced for group leader
-
-        # populated when assembling a group
-        self.group_endpoints_set: Set[Endpoint] = set()
-        self.assembled_group: asyncio.Future[Sequence[Endpoint]] = asyncio.Future()  # final ordered endpoints
-        self.concurrent_requests_lock = asyncio.Lock()  # lock inbound/outbound requests to join group
-
-        # populated when running allreduce
-        self.accumulator: Optional[torch.Tensor] = None   # the sum of averaged tensors so far, init with zeros
-        self.accumulated_from: Set[Endpoint] = set()      # peers that we have accumulated our part from
-        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()
-
-        self.average_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers
-        self.averaged_tensors: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.info.endpoint}, {self.state})"
-
-    def __await__(self):
-        return self.averaged_tensors.__await__()
-
-    def start_new_group(self, max_size: Optional[int] = None):
-        """ Create new group with a random id, become its leader and the only participant """
-        assert self.state == ProtocolState.LOOKING_FOR_GROUP
-        self.group_id = DHTID.generate().to_bytes()
-        # note: we generate group_id as DHTID for convenience. Do not assume that it has DHTID-like properties
-        logger.debug(f"{self} - starting a new group as a leader. Group id: {self.group_id}")
-        self.state = ProtocolState.LEADER_WAITING_FOR_PEERS
-        self.leader_endpoint = self.info.endpoint
-        self.group_endpoints_set = {self.info.endpoint}
-        if max_size is not None:
-            self.max_size = max_size
-
-    @property
-    def group_size(self):
-        assert self.state in (ProtocolState.LEADER_WAITING_FOR_PEERS, ProtocolState.RUNNING_ALLREDUCE)
-        return len(self.group_endpoints_set)
-
-    def join_group(self, leader_endpoint: Endpoint, group_id: GroupID):
-        """ After you were accepted by a leader, create your local instance using the metadata he sent """
-        self.group_id, self.leader_endpoint = group_id, leader_endpoint
-        logger.debug(f"{self} - joining the group of {leader_endpoint}. Group id: {self.group_id}")
-        self.state = ProtocolState.FOLLOWER_WAITING_FOR_LEADER
-
-    def add_peer_to_group(self, follower: Endpoint):
-        """ Add peer to a group, assuming that he can be added (self.get_reasons_to_reject(peer) is None) """
-        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS
-        assert follower not in self.group_endpoints_set
-        self.group_endpoints_set.add(follower)
-        logger.debug(f"{self} - adding {follower} to my group. New size = {self.group_size}")
-        if self.group_size > self.max_size:
-            logger.warning(f"{self} - group size ({self.group_size}) exceeded max size ({self.max_size})")
-
-    def remove_peer_from_group(self, follower: Endpoint):
-        """ Remove a disconnected peer from current group """
-        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS
-        assert follower in self.group_endpoints_set and follower != self.leader_endpoint
-        self.group_endpoints_set.remove(follower)
-        logger.info(f"{self} - removed {follower} from the group. New size = {self.group_size}")
-
-    def disband_group(self):
-        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS and self.group_size == 1
-        logger.info(f"{self} - disbanded group (reason = empty)")
-        self.state = ProtocolState.LOOKING_FOR_GROUP
-
-    def leader_begin_allreduce(self) -> averaging_pb2.MessageFromLeader:
-        """ As a leader, distribute allreduce metadata to peers and start allreduce """
-        assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS and self.group_size > 1
-        logger.debug(f"{self} - initiating allreduce for {self.group_endpoints_set} peers.")
-        ordered_group_endpoints = list(self.group_endpoints_set)
-        random.shuffle(ordered_group_endpoints)
-        self.assembled_group.set_result(ordered_group_endpoints)
-        self.state = ProtocolState.RUNNING_ALLREDUCE
-
-    def follower_begin_allreduce(self, ordered_group_endpoints: Sequence[Endpoint]):
-        """ As a follower, receive the final list of peers from the leader and begin sending data around """
-        assert self.state == ProtocolState.FOLLOWER_WAITING_FOR_LEADER and self.info.endpoint in ordered_group_endpoints
-        logger.debug(f"{self} - received peer order from the leader, beginning allreduce.")
-        self.group_endpoints_set = set(ordered_group_endpoints)
-        self.assembled_group.set_result(ordered_group_endpoints)
-        self.state = ProtocolState.RUNNING_ALLREDUCE
-
-    async def accumulate(self, source: Endpoint, part: torch.Tensor) -> torch.Tensor:
-        """ Add vector part to accumulator, wait for all other vectors to be added, return the average """
-        assert source not in self.accumulated_from, "duplicate endpoint, already received that part"
-        assert self.accumulator is None or self.accumulator.shape == part.shape
-        logger.debug(f"{self} - accumulated part from {source}")
-
-        self.accumulator = part if self.accumulator is None else self.accumulator.add_(part)
-        self.accumulated_from.add(source)
-
-        ordered_group_endpoints = await self.assembled_group
-        assert len(self.accumulated_from) <= len(ordered_group_endpoints)
-        if len(self.accumulated_from) == len(ordered_group_endpoints):
-            self.averaged_part.set_result(self.accumulator.div_(len(self.accumulated_from)))
-
-        return await self.averaged_part
-
-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-
-    async def handle_join_request(self, request: averaging_pb2.PeerInfo
-                                  ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
-        """ accept or reject a join request; if accepted, run him through allreduce steps """
-        should_remove_peer = False
-        try:
-            # stage 1: check if there is a reason to reject a peer outright
-            if not is_valid_join_request(request):
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
-                return
-            if self.info.expiration > (request.expiration or float('inf')):
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
-            elif request.schema_hash != self.info.schema_hash:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
-                return
-            elif request.endpoint == self.info.endpoint or request.endpoint in (self.group_endpoints_set or ()):
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
-                return
-            elif self.state == ProtocolState.FOLLOWER_WAITING_FOR_LEADER:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER,
-                                                      suggested_leader=self.leader_endpoint)
-                return
-            elif self.state == ProtocolState.RUNNING_ALLREDUCE or len(self.accumulated_from) > 0:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ALREADY_RUNNING)
-                return
-            if self.state == ProtocolState.LEADER_WAITING_FOR_PEERS and self.group_size >= self.max_size:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
-                return
-
-            # stage 2: add peer to group, optionally start a new one
-            async with self.concurrent_requests_lock:
-                if self.state == ProtocolState.LOOKING_FOR_GROUP:
-                    self.start_new_group()
-
-                assert self.state == ProtocolState.LEADER_WAITING_FOR_PEERS
-
-                self.add_peer_to_group(request.endpoint)
-                should_remove_peer = True
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED, group_id=self.group_id)
-
-            if self.group_size >= self.max_size:
-                self.leader_begin_allreduce()
-
-            # stage 3: wait for the group to be assembled and return
-            ordered_group_endpoints = await self.assembled_group
-            if ordered_group_endpoints is not None:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BEGIN_ALLREDUCE,
-                                                      ordered_group_endpoints=ordered_group_endpoints)
-            else:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
-
-        except Exception as e:
-            logger.exception(e)
-            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
-
-        finally:  # this code is guaranteed to run if the iterator is destroyed prematurely
-            if should_remove_peer:
-                self.remove_peer_from_group(request.endpoint)
-                if self.group_size <= 1:
-                    self.set_exception(ValueError("All peers have left"))
-
-    async def request_join_group(self, leader: Endpoint
-                                 ) -> Optional[grpc.aio.UnaryStreamCall[averaging_pb2.MessageFromLeader]]:
-        """ request a given peer to be your leader for allreduce. if accepted, return a grpc stream """
-        assert self.state == ProtocolState.LOOKING_FOR_GROUP
-        try:
-            async with self.concurrent_requests_lock:
-                stream = self._get_peer_stub(leader).rpc_group_allreduce(self.info)
-                message = await stream.read()
-                logger.debug(f"{self} - requested {leader} to be my leader, received "
-                             f"{averaging_pb2.MessageCode.Name(message.code)}")
-                if message.code == averaging_pb2.ACCEPTED:
-                    self.join_group(leader, message.group_id)
-                    return stream
-
-        except Exception as e:
-            self.set_exception(e)
-
-    async def wait_for_allreduce(self, stream: grpc.aio.UnaryStreamCall[averaging_pb2.MessageFromLeader]) -> bool:
-        """ the second part of request_join_group, return True if started allreduce, False if failed or disbanded """
-        try:
-            message = await stream.read()
-            if message.code == averaging_pb2.BEGIN_ALLREDUCE:
-                logger.debug(f"{self} - leader triggered allreduce")
-                assert all(isinstance(p, Endpoint) for p in message.ordered_group_endpoints)
-                self.follower_begin_allreduce(message.ordered_group_endpoints)
-                return True
-            else:
-                logger.debug(f"{self} - leader sent {averaging_pb2.MessageCode.Name(message.code)}, leaving group")
-                self.state = ProtocolState.GROUP_DISBANDED
-                return False
-        except Exception as e:
-            self.set_exception(e)
-            return False
-
-    async def run_allreduce(self) -> Sequence[torch.Tensor]:
-        """ send allreduce requests to all peers and collect results, return the averaged tensor """
-        assert self.state == ProtocolState.RUNNING_ALLREDUCE
-        ordered_group_endpoints = await self.assembled_group
-        ordered_local_parts = split_into_parts(self.local_tensors, group_size=self.group_size)
-
-        async def send_part(peer_endpoint: Endpoint, local_part: torch.Tensor):
-            if peer_endpoint == self.info.endpoint:
-                self.average_tensor_parts[peer_endpoint] = await self.accumulate(peer_endpoint, local_part)
-            else:
-                serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-                response = await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
-                    group_id=self.group_id, endpoint=self.info.endpoint, tensor_part=serialized_tensor_part))
-
-                if response.code == averaging_pb2.ACCEPTED:
-                    self.average_tensor_parts[peer_endpoint] = deserialize_torch_tensor(response.tensor_part)
-                else:
-                    raise ValueError(f"peer {peer_endpoint} replied {averaging_pb2.MessageCode.Name(response.code)}")
-
-            if len(self.average_tensor_parts) >= len(self.group_endpoints_set):
-                ordered_parts = [self.average_tensor_parts[peer] for peer in ordered_group_endpoints]
-                tensor_shapes = [tensor.shape for tensor in self.local_tensors]
-                self.averaged_tensors.set_result(restore_from_parts(ordered_parts, tensor_shapes))
-
-        try:
-            await asyncio.gather(*map(send_part, ordered_group_endpoints, ordered_local_parts))
-            return await self.averaged_tensors
-        except Exception as e:
-            code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
-
-            async def send_error_to_peer(peer_endpoint):
-                await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
-                    group_id=self.group_id, endpoint=self.info.endpoint, code=code))
-            for peer_endpoint in ordered_group_endpoints:
-                asyncio.create_task(send_error_to_peer(peer_endpoint))
-            if code == averaging_pb2.CANCELLED:
-                self.cancel()
-            else:
-                self.set_exception(e)
-            raise
-
-    async def handle_accumulate_request(self, request: averaging_pb2.AveragingData) -> averaging_pb2.AveragingData:
-        """ respond to an incoming rpc_accumulate_part """
-        if self.state not in (ProtocolState.RUNNING_ALLREDUCE, ProtocolState.FOLLOWER_WAITING_FOR_LEADER):
-            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
-        elif request.group_id != self.group_id:
-            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
-        elif request.endpoint in self.accumulated_from:
-            return averaging_pb2.AveragingData(code=averaging_pb2.DUPLICATE_ENDPOINT)
-
-        if request.code in (averaging_pb2.INTERNAL_ERROR, averaging_pb2.CANCELLED):
-            self.set_exception(ValueError(f"{request.endpoint} sent {averaging_pb2.MessageCode.Name(request.code)}"))
-            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
-
-        try:
-            received_part = deserialize_torch_tensor(request.tensor_part)
-            averaged_part = await self.accumulate(request.endpoint, received_part)
-            serialized = serialize_torch_tensor(averaged_part, request.tensor_part.compression, allow_inplace=False)
-            return averaging_pb2.AveragingData(code=averaging_pb2.ACCEPTED, tensor_part=serialized)
-        except asyncio.CancelledError:
-            self.cancel()
-            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
-        except Exception as e:
-            self.set_exception(e)
-            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
-
-    def cancel(self):
-        logger.debug(f"{self} - cancelled")
-        self.state = ProtocolState.CANCELLED
-        for future in self.assembled_group, self.averaged_part, self.averaged_tensors:
-            future.cancel()
-
-    def set_exception(self, exception: Exception):
-        logger.debug(f"{self} - {exception}")
-        self.state = ProtocolState.ERROR
-        for future in self.assembled_group, self.averaged_part, self.averaged_tensors:
-            future.set_exception(exception)
-
-
-def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor]:
-    """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
-    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
-    chunk_slices = torch.linspace(start=0, end=len(flat_tensor), steps=group_size + 1, dtype=torch.int64)
-    chunk_slices[-1] = len(flat_tensor)
-    return tuple(torch.as_tensor(flat_tensor[chunk_slices[i]: chunk_slices[i + 1]]) for i in range(group_size))
-
-
-def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
-    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
-    flat_tensor = torch.cat(list(chunks))
-    result_sizes = tuple(map(torch.Size.numel, shapes))
-    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
-    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
-
-
-def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
-    """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
-    schema_dicts = [{field_name: str(field_value)
-                    for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
-                    for tensor in tensors]
-    return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
-
-
-def is_valid_join_request(request: averaging_pb2.PeerInfo) -> bool:
-    assert len(request.ListFields()) == 3, "this function assumes JoinRequest has three fields, it should be updated"
-    return (isinstance(request.schema_hash, bytes) and
-            isinstance(request.expiration, DHTExpiration) and
-            isinstance(request.endpoint, Endpoint))
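
The schema-hash check above is what lets a leader reject followers whose tensors do not line up (BAD_SCHEMA_HASH); it hashes only tensor metadata, never values, and the new matchmaking.py below keeps the same helper. A minimal sketch of that property, reusing the hivemind utilities exactly as imported in this file (assumes the package at this commit is installed):

import torch
from dataclasses import asdict

from hivemind.dht import DHTID
from hivemind.utils import TensorDescriptor, MSGPackSerializer


def schema_hash(tensors):
    # same recipe as compute_schema_hash above: hash shapes/dtypes/devices, not values
    schema_dicts = [{name: str(value)
                     for name, value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                    for tensor in tensors]
    return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()


peer_a = [torch.zeros(3, 4), torch.ones(5)]
peer_b = [torch.randn(3, 4), torch.randn(5)]   # same layout, different values
assert schema_hash(peer_a) == schema_hash(peer_b)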

+ 0 - 185
hivemind/client/averager.py

@@ -1,185 +0,0 @@
-""" A background process that averages your tensors with peers """
-
-from __future__ import annotations
-
-import ctypes
-from typing import Sequence, Optional, Tuple, Any, Union, Awaitable, Dict
-from concurrent.futures.thread import ThreadPoolExecutor
-import multiprocessing as mp
-import asyncio
-
-import torch
-import uvloop
-import grpc
-
-import hivemind
-from hivemind.dht import get_dht_time, DHTExpiration
-from hivemind.utils import get_logger, Endpoint, Port, MPFuture
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
-from hivemind.client.allreduce import GroupAllReduce, GroupID
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc
-
-logger = get_logger(__file__)
-
-
-class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
-    """
-    **Warning!** Decentralized averager is in active development, some critical functionality is still underway
-
-    Gating function averaging service. A trainer can run this service in background to periodically average his gating
-    function with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
-    group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.
-    Why averaging is valid: see https://github.com/learning-at-home/hivemind/issues/95#issuecomment-688806705
-    On global convergence: see https://github.com/learning-at-home/hivemind/issues/95#issuecomment-717719400
-
-    :param averaged_tensors: a sequence of pytorch tensors that will be averaged in each all-reduce
-    :param dht: a DHT node that will be used to find groups
-    :param start: if True, starts the background process immediately
-    :param timeout: consider allreduce failed if there was no activity for this many **seconds**
-    :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
-            if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
-    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
-    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
-          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
-    :param kwargs: extra parameters forwarded to in grpc.aio.server
-    You can perform averaging using DecentralizedOptimizer (see below) or by manually running each step as such:
-
-    >> TODO add a working example
-    """
-
-    def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
-                 max_size: int = None, timeout: float = 15, listen: bool = True, listen_on: Endpoint = '0.0.0.0:*',
-                 receiver_threads: int = 1, channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
-        super().__init__()
-        self.dht = dht
-        self.server_opts = listen, listen_on, receiver_threads, kwargs
-        self.max_size = max_size if max_size is not None else float('inf')
-        self.timeout = timeout
-        self.channel_options = channel_options
-        self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
-        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
-        self._pending_groups: Dict[GroupID, GroupAllReduce] = {}
-        self._lock_forming_a_group: Optional[asyncio.Lock] = None
-        self.ready = mp.Event()
-
-        self.averaged_tensors = tuple(averaged_tensors)
-        for tensor in self.averaged_tensors:
-            assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
-            tensor.share_memory_()
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-    @property
-    def port(self) -> Optional[Port]:
-        return self._port.value if self._port.value != 0 else None
-
-    def run(self):
-        """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
-        if asyncio.get_event_loop().is_running():
-            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-
-        uvloop.install()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        listen, listen_on, receiver_threads, server_kwargs = self.server_opts
-        pipe_awaiter = ThreadPoolExecutor(receiver_threads)
-        self._lock_forming_a_group = asyncio.Lock()
-
-        async def _run():
-            if listen:
-                grpc.aio.init_grpc_aio()
-                server = grpc.aio.server(**server_kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
-                found_port = server.add_insecure_port(listen_on)
-                assert found_port != 0, f"Failed to listen to {listen_on}"
-                self._port.value = found_port
-                await server.start()
-                self.ready.set()
-            else:
-                raise NotImplementedError("Client-only averaging is not implemented yet.")
-
-            while True:
-                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                asyncio.create_task(getattr(self, method)(*args, **kwargs))
-
-        loop.run_until_complete(_run())
-
-    def run_in_background(self, await_ready=True, timeout=None):
-        """
-        Starts averager in a background process. if await_ready, this method will wait until background dht
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
-
-    def shutdown(self) -> None:
-        """ Shut down the averager process """
-        # TODO notify peers before terminating
-        if self.is_alive():
-            self.terminate()
-        else:
-            logger.warning("DHT shutdown has no effect: the process is not alive")
-
-    def group_allreduce(self, my_endpoint: Endpoint, leader_endpoint: Optional[Endpoint] = None,
-                        return_future=False) -> Union[Sequence[torch.Tensor], Awaitable[Sequence[torch.Tensor]]]:
-        """
-        Set up the averager to look for a group and run all-reduce once, optionally await and return outcome
-
-        :note: this function implemented for debugging and will be removed in future versions
-        :param my_endpoint: public endpoint of this averager
-        :param leader_endpoint: if specified, attempts to join this peer's group
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        """
-        expiration = get_dht_time() + self.timeout
-        assert isinstance(expiration, DHTExpiration)
-
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_group_allreduce', [], dict(my_endpoint=my_endpoint, expiration=expiration,
-                                                     leader_endpoint=leader_endpoint, future=_future)))
-        return future if return_future else future.result()
-
-    async def _group_allreduce(self, *, my_endpoint: Endpoint, expiration: DHTExpiration,
-                               leader_endpoint: Optional[Endpoint], future: MPFuture):
-        group_allreduce = GroupAllReduce(my_endpoint, expiration, self.averaged_tensors)
-        try:
-            if leader_endpoint is None:
-                async with self._lock_forming_a_group:
-                    group_allreduce.start_new_group(max_size=self.max_size)
-                    self._forming_group = self._pending_groups[group_allreduce.group_id] = group_allreduce
-                    await asyncio.wait_for(group_allreduce.assembled_group, expiration - get_dht_time())
-
-                future.set_result(await group_allreduce.run_allreduce())
-            else:
-                async with self._lock_forming_a_group:
-                    stream = await group_allreduce.request_join_group(leader_endpoint)
-                    self._forming_group = self._pending_groups[group_allreduce.group_id] = group_allreduce
-
-                started_allreduce = await group_allreduce.wait_for_allreduce(stream)
-                if started_allreduce:
-                    future.set_result(await group_allreduce.run_allreduce())
-                else:
-                    future.set_exception(ValueError(f"Rejected by {leader_endpoint}"))
-
-        except Exception as e:
-            future.set_exception(e)
-        finally:
-            _ = self._pending_groups.pop(group_allreduce.group_id, None)
-            if group_allreduce is self._forming_group:
-                self._forming_group = None
-
-    async def rpc_group_allreduce(self, request: averaging_pb2.PeerInfo, context: grpc.ServicerContext):
-        """ A peer wants me to be his leader. I will coordinate his actions with the rest of my group. Maybe. """
-        if self._forming_group is None:
-            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)
-            return
-        async for message in self._forming_group.handle_join_request(request):
-            yield message
-
-    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
-        if request.group_id not in self._pending_groups:
-            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
-        return await self._pending_groups[request.group_id].handle_accumulate_request(request)
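
Both this file and its replacement below drive the background process the same way: the caller serializes a ('method_name', args, kwargs) tuple into a multiprocessing pipe, and the process-side event loop receives it through a thread executor and dispatches it via getattr. A stripped-down sketch of that control-pipe pattern; PipeWorker and _echo are illustrative names, not part of hivemind:

import asyncio
import multiprocessing as mp
import time
from concurrent.futures import ThreadPoolExecutor


class PipeWorker(mp.Process):
    def __init__(self):
        super().__init__(daemon=True)
        self._inner, self.outer = mp.Pipe(duplex=True)  # outer end stays with the caller

    def run(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        pipe_awaiter = ThreadPoolExecutor(max_workers=1)

        async def _serve_forever():
            while True:
                # block on the pipe in a worker thread so the event loop stays responsive
                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner.recv)
                asyncio.create_task(getattr(self, method)(*args, **kwargs))

        loop.run_until_complete(_serve_forever())

    async def _echo(self, message: str):
        print(f"worker got: {message}")


if __name__ == '__main__':
    worker = PipeWorker()
    worker.start()
    worker.outer.send(('_echo', ['hello'], {}))  # same wire format as averager.pipe.send(...)
    time.sleep(1)
    worker.terminate()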

+ 216 - 0
hivemind/client/averaging/__init__.py

@@ -0,0 +1,216 @@
+""" A background process that averages your tensors with peers """
+
+from __future__ import annotations
+
+import random
+import ctypes
+from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
+from concurrent.futures.thread import ThreadPoolExecutor
+import multiprocessing as mp
+import asyncio
+
+import torch
+import uvloop
+import grpc
+
+import hivemind
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
+from hivemind.client.averaging.matchmaking import Matchmaking
+from hivemind.utils import get_logger, Endpoint, Port, MPFuture, replace_port, GRPC_KEEPALIVE_OPTIONS
+from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+
+# flavour types
+StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
+
+INITIAL_GROUP_NBITS = 3
+logger = get_logger(__file__)
+
+
+class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
+    """
+    **Warning!** Decentralized averager is in active development, some critical functionality is still underway
+
+    Parameter averaging service. A trainer can run this service in background to periodically average his parameters
+    with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
+    group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.
+
+    :param averaged_tensors: a sequence of pytorch tensors that will be averaged in each all-reduce
+    :param dht: a DHT node that will be used to find groups
+    :param start: if True, starts the background process immediately
+
+    :param prefix: a shared prefix for all group keys
+    :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
+    :param initial_group_bits: a string of bits ('0' and '1') that define initial group key (bucket index)
+      by default, sample a random bit sequence of length {INITIAL_GROUP_NBITS}
+    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
+      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
+    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+
+    :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
+            if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
+    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
+    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
+          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
+    :param kwargs: extra parameters forwarded to grpc.aio.server
+    You can perform averaging using DecentralizedOptimizer (see below) or by manually running each step as such:
+
+    >> TODO add a working example here
+    """
+    _matchmaking: Matchmaking
+    _pending_group_assembled: asyncio.Event
+
+    def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
+                 prefix: str, target_group_size: int, min_group_size: int = 1, initial_group_bits: Optional[str] = None,
+                 averaging_expiration: float = 15, allreduce_timeout: Optional[float] = None,
+                 compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+                 listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1,
+                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
+        assert '.' not in prefix, "group prefix must be a string without ."
+        if not is_power_of_two(target_group_size):
+            logger.warning("It is recommended to set target_group_size to a power of 2.")
+        if initial_group_bits is None:
+            initial_group_bits = ''.join(random.choices('01', k=INITIAL_GROUP_NBITS))
+            logger.debug(f"Initializing with random {INITIAL_GROUP_NBITS}-bit group index: {initial_group_bits}")
+        assert len(initial_group_bits) >= INITIAL_GROUP_NBITS and all(bit in '01' for bit in initial_group_bits)
+
+        super().__init__()
+        self.dht = dht
+        self.listen_on, self.receiver_threads, self.kwargs = listen_on, receiver_threads, kwargs
+        self.channel_options = channel_options
+        self.averaged_tensors = tuple(averaged_tensors)
+        # TODO use mp.Lock to prevent someone from modifying tensors before we copy them! maybe.
+        for tensor in self.averaged_tensors:
+            assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
+            tensor.share_memory_()
+
+        self.matchmaking_kwargs = dict(prefix=prefix, initial_group_bits=initial_group_bits,
+                                       target_group_size=target_group_size, min_group_size=min_group_size,
+                                       averaging_expiration=averaging_expiration)
+        self.allreduce_timeout, self.compression_type = allreduce_timeout, compression_type
+        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
+
+        self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
+        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
+        self._averager_endpoint: Optional[Endpoint] = None
+        self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    @property
+    def port(self) -> Optional[Port]:
+        return self._port.value if self._port.value != 0 else None
+
+    @property
+    def endpoint(self) -> Endpoint:
+        if self._averager_endpoint is None:
+            self._averager_endpoint = replace_port(self.listen_on, self.port if self.port is not None else '*')
+            logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
+        return self._averager_endpoint
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.endpoint})"
+
+    def run(self):
+        """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
+        if asyncio.get_event_loop().is_running():
+            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+
+        uvloop.install()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        # initialize asyncio synchronization primitives in this event loop
+        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
+
+        async def _run():
+            grpc.aio.init_grpc_aio()
+            server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
+            averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
+            found_port = server.add_insecure_port(self.listen_on)
+            assert found_port != 0, f"Failed to listen to {self.listen_on}"
+            self._port.value = found_port
+            self._matchmaking = Matchmaking(self.endpoint, self.averaged_tensors, self.dht, **self.matchmaking_kwargs)
+            self._pending_group_assembled = asyncio.Event()
+            self._pending_group_assembled.set()
+            await server.start()
+            self.ready.set()
+
+            while True:
+                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                asyncio.create_task(getattr(self, method)(*args, **kwargs))
+
+        loop.run_until_complete(_run())
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts averager in a background process. If await_ready, this method will wait until the background averager
+        is ready to process incoming requests, or for :timeout: seconds at most.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+
+    def shutdown(self) -> None:
+        """ Shut down the averager process """
+        # TODO notify peers before terminating
+        if self.is_alive():
+            self.terminate()
+        else:
+            logger.warning("DHT shutdown has no effect: the process is not alive")
+
+    def step(self, timeout: Optional[float] = None, return_future=False) -> Union[Sequence[torch.Tensor], MPFuture]:
+        """
+        Set up the averager to look for a group and run one round of averaging, then return the averaged tensors
+
+        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_step', [], dict(future=_future, timeout=timeout)))
+        return future if return_future else future.result()
+
+    async def _step(self, *, future: MPFuture, timeout: Optional[float]):
+        group_id = None
+        try:
+            self._pending_group_assembled.clear()
+            allreduce_group = await self._matchmaking.look_for_group(timeout=timeout)
+            if allreduce_group is None:
+                raise AllreduceException(f"{self} - group_allreduce failed, unable to find a group")
+
+            group_id = allreduce_group.group_id
+            self._running_groups[group_id] = allreduce_group
+            self._pending_group_assembled.set()
+            future.set_result(await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout))
+
+        except Exception as e:
+            future.set_exception(e)
+            raise
+        finally:
+            self._pending_group_assembled.set()
+            if group_id is not None:
+                _ = self._running_groups.pop(group_id, None)
+
+    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
+        """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
+        async for response in self._matchmaking.rpc_join_group(request, context):
+            yield response
+
+    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
+        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
+        if request.group_id not in self._running_groups and not self._pending_group_assembled.is_set():
+            # this handles a special case when leader accepted us to group AND began allreduce right away,
+            # but his response with group_id was delayed and other peers got to us first
+            await self._pending_group_assembled.wait()
+        if request.group_id not in self._running_groups:
+            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        else:
+            return await self._running_groups[request.group_id].rpc_aggregate_part(request, context)
+
+
+def is_power_of_two(n):
+    """ Check whether n is a power of 2 """
+    return (n != 0) and (n & (n - 1) == 0)
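
The docstring above still marks the usage example as TODO, so the following is only a hypothetical sketch: it assumes a reachable hivemind DHT node can be created with the listen_on/start arguments used elsewhere in the project, and the values for prefix and target_group_size are illustrative.

import torch
import hivemind
from hivemind.client.averaging import DecentralizedAverager

dht = hivemind.DHT(listen_on='0.0.0.0:*', start=True)   # assumed: a DHT node reachable by peers
parameters = [torch.randn(16, 16), torch.zeros(16)]     # leaf tensors, as the constructor requires

averager = DecentralizedAverager(parameters, dht, start=True,
                                 prefix='demo_experiment', target_group_size=4)

# look for a group, then run one butterfly all-reduce round and return the averaged tensors
averaged_tensors = averager.step(timeout=60)

In practice several trainers with the same prefix and tensor schema call step() concurrently so that groups can form.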

+ 184 - 0
hivemind/client/averaging/allreduce.py

@@ -0,0 +1,184 @@
+import asyncio
+from typing import Sequence, Set, Dict, Tuple
+
+import grpc
+import torch
+
+from hivemind.utils import Endpoint, get_logger, serialize_torch_tensor, deserialize_torch_tensor, ChannelCache
+from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
+
+# flavour types
+GroupID = bytes
+logger = get_logger(__name__)
+
+
+class AllReduceProtocol:
+    """
+    An internal class that runs butterfly AllReduce in a predefined group of averagers
+
+    :param tensors: local tensors that should be averaged with groupmates
+    :param endpoint: your endpoint, must be included in ordered_group_endpoints
+    :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    """
+    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
+                 ordered_group_endpoints: Sequence[Endpoint]):
+        assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, self.group_size)))
+        self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
+
+        self.accumulator = self.local_tensor_parts[self.endpoint].clone()  # sum inputs from peers to this tensor
+        self.accumulated_from: Set[Endpoint] = {self.endpoint}  # peers that we have accumulated our part from
+        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
+        self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
+        self.averaged_tensors: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
+
+    def __await__(self):
+        return self.averaged_tensors.__await__()
+
+    @property
+    def group_size(self):
+        return len(self.ordered_group_endpoints)
+
+    async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor) -> torch.Tensor:
+        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
+        assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
+        assert not self.averaged_tensors.done(), f"already finished allreduce: {self.averaged_tensors}"
+        assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
+        assert source not in self.accumulated_from, "duplicate source, already received that part"
+        logger.debug(f"{self} - accumulating tensor part from {source}")
+
+        self.accumulator.add_(remote_part)
+        self.accumulated_from.add(source)
+
+        assert len(self.accumulated_from) <= self.group_size
+        if len(self.accumulated_from) == len(self.local_tensor_parts):
+            average_result = self.accumulator.div_(len(self.accumulated_from))
+            self.register_averaged_part(self.endpoint, average_result)
+            self.averaged_part.set_result(average_result)
+
+        return await self.averaged_part
+
+    def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
+        assert not self.averaged_tensors.done(), f"already finished allreduce: {self.averaged_tensors}"
+        assert source in self.local_tensor_parts, "the provider of averaged part is not from my group"
+        assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
+        assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
+        assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
+        logger.debug(f"{self} - receiving averaged tensor part from {source}")
+        self.averaged_tensor_parts[source] = averaged_part
+        if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
+            ordered_averaged_parts = [self.averaged_tensor_parts[endpoint] for endpoint in self.ordered_group_endpoints]
+            self.averaged_tensors.set_result(restore_from_parts(ordered_averaged_parts, self.tensor_shapes))
+
+    def cancel(self) -> bool:
+        if not self.averaged_tensors.done():
+            logger.debug(f"{self} - cancelled")
+            self.averaged_tensors.cancel()
+            if not self.averaged_part.done():
+                self.averaged_part.cancel()
+            return True
+        else:
+            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.averaged_tensors}")
+            return False
+
+    def set_exception(self, exception: Exception) -> bool:
+        if not self.averaged_tensors.done():
+            logger.debug(f"{self} - {exception}")
+            self.averaged_tensors.set_exception(exception)
+            if not self.averaged_part.done():
+                self.averaged_part.cancel()
+            return True
+        else:
+            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.averaged_tensors}")
+            return False
+
+
+class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragingServicer):
+    """
+    A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
+    """
+    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
+                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType):
+        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint,
+                         ordered_group_endpoints=ordered_group_endpoints)
+        self.compression_type = compression_type
+
+    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
+        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+
+    async def _average_one_part(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
+        """ Send one part of local tensors to one groupmate and collect the average for this part """
+        serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
+        response = await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(
+            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
+                                        endpoint=self.endpoint, tensor_part=serialized_tensor_part))
+        if response.code == averaging_pb2.AVERAGED_PART:
+            averaged_part = deserialize_torch_tensor(response.tensor_part)
+            self.register_averaged_part(peer_endpoint, averaged_part)
+            return averaged_part
+        else:
+            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(response.code)}"
+                                     f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
+                                     f" allreduce failed")
+
+    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
+        await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
+            group_id=self.group_id, endpoint=self.endpoint, code=code))
+
+    async def run(self) -> Sequence[torch.Tensor]:
+        """ send allreduce requests to all peers and collect results, return the averaged tensor """
+        try:
+            await asyncio.gather(self, *(self._average_one_part(peer, part)
+                                         for peer, part in self.local_tensor_parts.items() if peer != self.endpoint))
+            return await self
+        except Exception as e:
+            code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
+            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
+            self.set_exception(e)
+            for peer_endpoint in self.ordered_group_endpoints:
+                asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
+            raise
+
+    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
+        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
+        if request.group_id != self.group_id:
+            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+
+        if request.code == averaging_pb2.PART_FOR_AVERAGING:
+            try:
+                tensor_part = deserialize_torch_tensor(request.tensor_part)
+                averaged_part = await self.accumulate_part(request.endpoint, tensor_part)
+                serialized = serialize_torch_tensor(averaged_part, request.tensor_part.compression, allow_inplace=False)
+                return averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized)
+            except Exception as e:
+                self.set_exception(e)
+                return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+        else:
+            error_code = averaging_pb2.MessageCode.Name(request.code)
+            logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
+            self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+
+
+def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor, ...]:
+    """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
+    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
+    chunk_slices = torch.linspace(start=0, end=len(flat_tensor), steps=group_size + 1, dtype=torch.int64)
+    chunk_slices[-1] = len(flat_tensor)
+    return tuple(flat_tensor[chunk_slices[i]: chunk_slices[i + 1]] for i in range(group_size))
+
+
+def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
+    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
+    flat_tensor = torch.cat(tuple(chunks))
+    result_sizes = tuple(map(torch.Size.numel, shapes))
+    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
+    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
+
+
+class AllreduceException(Exception):
+    """ A special exception that is raised when allreduce can't continue normally (e.g. disbanded/bad request/etc) """

+ 394 - 0
hivemind/client/averaging/matchmaking.py

@@ -0,0 +1,394 @@
+""" A background process that averages your tensors with peers """
+
+from __future__ import annotations
+
+import contextlib
+import random
+from dataclasses import asdict
+from math import isfinite
+from typing import Sequence, Optional, AsyncIterator, Set
+import asyncio
+
+import torch
+import grpc
+
+import hivemind
+from hivemind.client.averaging.allreduce import AllReduceRunner, GroupID
+from hivemind.dht import DHTID, DHTExpiration, get_dht_time, GroupKey
+from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, TimedStorage
+from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+from hivemind.utils.grpc import ChannelCache
+
+
+logger = get_logger(__file__)
+
+
+class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
+    f"""
+    An internal class that is used to form groups of averages for running allreduce
+    See DecentralizedAverager docstring for the detailed description of all parameters
+    """
+
+    def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
+                 prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
+                 averaging_expiration: float = 15, compression_type: runtime_pb2.CompressionType = runtime_pb2.NONE):
+        assert '.' not in prefix, "group prefix must be a string without ."
+
+        super().__init__()
+        self.dht, self.endpoint, self.averaged_tensors = dht, endpoint, tuple(averaged_tensors)
+        self.prefix, self.group_bits = prefix, initial_group_bits
+        self.target_group_size, self.min_group_size = target_group_size, min_group_size
+        self.averaging_expiration, self.compression_type = averaging_expiration, compression_type
+
+        self.schema_hash = compute_schema_hash(self.averaged_tensors)
+
+        self.lock_looking_for_group = asyncio.Lock()
+        self.lock_request_join_group = asyncio.Lock()
+        self.cond_notify_followers = asyncio.Condition()
+        self.assembled_group = asyncio.Future()
+
+        self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
+        self.current_followers: Set[Endpoint] = set()  # iff i am a leader, this contains my followers excluding myself
+        self.potential_leaders = PotentialLeaders(self.endpoint, self.dht, self.averaging_expiration)
+
+    @property
+    def is_looking_for_group(self):
+        return self.lock_looking_for_group.locked()
+
+    @property
+    def current_group_key(self) -> GroupKey:
+        return f"{self.prefix}.0b{self.group_bits}"
+
+    def __repr__(self):
+        lfg_status = "looking for group," if self.is_looking_for_group else "not looking for group,"
+        if self.is_looking_for_group:
+            if self.current_leader:
+                lfg_status += f" following {self.current_leader},"
+            if len(self.current_followers):
+                lfg_status += f" leading {len(self.current_followers)} followers,"
+        schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
+        return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
+               f" current key = {self.current_group_key})"
+
+    async def look_for_group(self, *, timeout: Optional[float] = None) -> AllReduceRunner:
+        """
+        Iterate over the averagers under the current group key that have higher leadership priority than yourself
+        and try to assemble a group with them; does NOT perform the actual averaging.
+        :returns: an assembled group (AllReduceRunner) if successful, raises an exception if failed
+        """
+        if self.is_looking_for_group:
+            logger.info("Another look_for_group is already in progress. The current run will be scheduled after"
+                        " the existing group is either assembled or disbanded.")
+        async with self.lock_looking_for_group:
+            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
+            try:
+                return await asyncio.wait_for(self.assembled_group, timeout=timeout)
+            except Exception as e:
+                if len(self.current_followers) > 0:
+                    async with self.lock_request_join_group:
+                        await self.leader_disband_group()
+                self.assembled_group.set_exception(e)
+                raise
+
+            finally:
+                if not request_leaders_task.done():
+                    request_leaders_task.cancel()
+                if self.assembled_group.done():
+                    self.assembled_group = asyncio.Future()
+
+    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
+        """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
+        end_time = get_dht_time() + timeout if timeout is not None else float('inf')
+        async with self.potential_leaders.begin_search(self.current_group_key, timeout):
+            # TODO update group_bits on success! reduce number of bits on not enough peers.
+            # TODO after allreduce finishes, we may need to ask leader to notify lower keys about this
+            # (so as to fix possible network partitioning if some peers operate on a much smaller nbits)
+            while True:
+                try:
+                    time_to_expiration = self.potential_leaders.declared_expiration_time - get_dht_time()
+                    next_best_leader = await asyncio.wait_for(
+                        self.potential_leaders.pop_next_leader(),
+                        timeout=time_to_expiration if isfinite(time_to_expiration) else None)
+
+                    request_expiration_time = min(self.potential_leaders.declared_expiration_time,
+                                                  end_time, get_dht_time() + self.averaging_expiration)
+                    group = await self.request_join_group(next_best_leader, request_expiration_time)
+                    if group is not None:
+                        return group
+
+                except asyncio.TimeoutError:
+                    async with self.lock_request_join_group:
+                        if len(self.current_followers) >= self.min_group_size:
+                            # the time is up, we have a *good enough* group. run allreduce as is.
+                            return await self.leader_assemble_group()
+                        else:
+                            await self.leader_disband_group()
+                            # TODO maybe adjust grid size
+                            continue
+
+    async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[AllReduceRunner]:
+        """
+        :param leader: request this peer to be your leader for allreduce
+        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
+        :returns: if the leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
+        :note: this function does not guarantee that your group leader is the same as the :leader: parameter;
+          the originally specified leader can disband the group and redirect us to a different leader
+        """
+        assert self.is_looking_for_group and self.current_leader is None
+        call: Optional[grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]] = None
+        try:
+            async with self.lock_request_join_group:
+                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+                call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
+                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time))
+
+                message = await call.read()
+                if message.code != averaging_pb2.ACCEPTED:
+                    code = averaging_pb2.MessageCode.Name(message.code)
+                    logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                    return None
+
+                # else: we were accepted
+                logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                self.current_leader = leader
+                if len(self.current_followers) > 0:
+                    await self.leader_disband_group()
+
+            async with self.potential_leaders.pause_search():
+                message = await call.read()
+
+            if message.code == averaging_pb2.BEGIN_ALLREDUCE:
+                async with self.lock_request_join_group:
+                    return await self.follower_assemble_group(leader, message.group_id, message.ordered_group_endpoints)
+            elif message.code == averaging_pb2.GROUP_DISBANDED and bool(message.suggested_leader):
+                logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
+                return await self.request_join_group(message.suggested_leader, expiration_time)
+
+            else:
+                logger.debug(f"{self} - leader sent {averaging_pb2.MessageCode.Name(message.code)}, leaving group")
+                return None
+        finally:
+            self.current_leader = None
+            if call is not None:
+                call.cancel()
+
+    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
+        """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
+        try:
+            reason_to_reject = self._check_reasons_to_reject(request)
+            if reason_to_reject is not None:
+                yield reason_to_reject
+                return
+
+            current_group = self.assembled_group  # keep a reference to the current future: it may be replaced once the group is assembled
+            async with self.lock_request_join_group:
+                self.current_followers.add(request.endpoint)
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
+
+                if len(self.current_followers) + 1 >= self.target_group_size:
+                    # outcome 1: we have assembled a full group and are ready for allreduce
+                    await self.leader_assemble_group()
+
+            if not current_group.done():
+                try:
+                    async with self.cond_notify_followers:
+                        # wait for the group to be assembled or disbanded
+                        timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
+                        await asyncio.wait_for(self.cond_notify_followers.wait(), timeout=timeout)
+                except asyncio.TimeoutError:
+                    async with self.lock_request_join_group:
+                        # outcome 2: the time is up, run allreduce with what we have or disband
+                        if len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
+                            await self.leader_assemble_group()
+                        else:
+                            await self.leader_disband_group()
+
+            if self.current_leader is not None:
+                # outcome 3: found by a leader with higher priority, send our followers to him
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
+                                                      suggested_leader=self.current_leader)
+                return
+
+            if request.endpoint not in self.current_followers:
+                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
+                return
+
+            # finally, run allreduce
+            allreduce_group = current_group.result()
+            yield averaging_pb2.MessageFromLeader(
+                code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
+                ordered_group_endpoints=allreduce_group.ordered_group_endpoints)
+
+        except Exception as e:
+            logger.exception(e)
+            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
+
+        finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
+            self.current_followers.discard(request.endpoint)
+
+    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
+        """ :returns: if accepted, return None, otherwise return a reason for rejection """
+        if not self.is_looking_for_group:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)
+
+        if not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
+                or not isinstance(request.expiration, DHTExpiration) or not isfinite(request.expiration) \
+                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
+
+        elif request.schema_hash != self.schema_hash:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+        elif self.potential_leaders.declared_group_key is None:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
+        elif self.potential_leaders.declared_expiration_time > (request.expiration or float('inf')):
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
+        elif self.current_leader is not None:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER,
+                                                   suggested_leader=self.current_leader)
+        elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
+        elif len(self.current_followers) + 1 >= self.target_group_size:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
+        else:
+            return None
+
+    async def leader_assemble_group(self) -> AllReduceRunner:
+        """ Form up all current followers into a group and prepare to _run_allreduce """
+        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        group_id = DHTID.generate().to_bytes()
+        ordered_group_endpoints = list(self.current_followers)
+        ordered_group_endpoints.append(self.endpoint)
+        random.shuffle(ordered_group_endpoints)
+        logger.debug(f"{self.endpoint} - leader started allreduce with {len(ordered_group_endpoints)} followers.")
+        allreduce_group = AllReduceRunner(
+            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        self.assembled_group.set_result(allreduce_group)
+        async with self.cond_notify_followers:
+            self.cond_notify_followers.notify_all()
+        return allreduce_group
+
+    async def follower_assemble_group(self, leader: Endpoint, group_id: GroupID,
+                                      ordered_group_endpoints: Sequence[Endpoint]) -> AllReduceRunner:
+        """ Prepare to run allreduce using a list of peers provided by our leader """
+        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
+        assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
+        assert self.endpoint in ordered_group_endpoints, "Leader sent us an ordered_group_endpoints list that does not include us!"
+        allreduce_group = AllReduceRunner(
+            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        self.assembled_group.set_result(allreduce_group)
+        async with self.cond_notify_followers:
+            self.cond_notify_followers.notify_all()
+        return allreduce_group
+
+    async def leader_disband_group(self):
+        """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
+        assert self.lock_request_join_group.locked()
+        self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out
+        async with self.cond_notify_followers:
+            self.cond_notify_followers.notify_all()
+
+
+class PotentialLeaders:
+    """ An utility class that searches for averagers that could become our leaders """
+    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration):
+        self.endpoint, self.dht, self.averaging_expiration = endpoint, dht, averaging_expiration
+        self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
+        self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
+        self.max_assured_time = float('-inf')
+        self.declared_expiration_time = float('inf')
+        self.declared_group_key: Optional[GroupKey] = None
+        self.search_end_time = float('inf')
+
+    @contextlib.asynccontextmanager
+    async def begin_search(self, group_key: GroupKey, timeout: Optional[float]):
+        assert not self.running.is_set(), "already running"
+        self.running.set()
+        self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
+        update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
+        declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
+        try:
+            yield self
+        finally:
+            update_queue_task.cancel()
+            declare_averager_task.cancel()
+            self.running.clear()
+            self.update_triggered.clear()
+            self.update_finished.clear()
+
+    @contextlib.asynccontextmanager
+    async def pause_search(self):
+        was_running = self.running.is_set()
+        try:
+            self.running.clear()
+            yield
+        finally:
+            if was_running:
+                self.running.set()
+            else:
+                self.running.clear()
+
+    async def pop_next_leader(self) -> Endpoint:
+        """ Remove and return the next most suitable leader or throw an exception if reached timeout """
+        assert self.running, "Not running search at the moment"
+        maybe_next_leader, entry = self.leader_queue.top()
+
+        next_entry_time = entry.expiration_time if maybe_next_leader is not None else get_dht_time()
+        if self.max_assured_time < next_entry_time < self.search_end_time:
+            self.update_triggered.set()
+
+        if maybe_next_leader is None:
+            await self.update_finished.wait()
+            return await self.pop_next_leader()
+
+        del self.leader_queue[maybe_next_leader]
+        return maybe_next_leader
+
+    async def _update_queue_periodically(self, group_key: GroupKey):
+        DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+        while get_dht_time() < self.search_end_time:
+            new_peers = await self.dht.get_averagers(group_key, only_active=True, return_future=True)
+            self.max_assured_time = max(self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY)
+
+            for peer, peer_expiration_time in new_peers:
+                if peer == self.endpoint:
+                    continue
+                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+
+            if len(self.leader_queue) > 0:
+                self.update_finished.set()
+
+            await asyncio.wait(
+                {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
+                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
+            self.update_triggered.clear()
+
+    async def _declare_averager_periodically(self, group_key: GroupKey):
+        try:
+            while True:
+                new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
+                stored_ok = await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
+                                                            looking_for_group=True, return_future=True)
+                if stored_ok:
+                    await asyncio.sleep(self.declared_expiration_time - get_dht_time())
+                else:
+                    logger.warning(f"Failed to subscribe to group {group_key} : store rejected by DHT peers")
+        finally:
+            if self.declared_group_key is not None:
+                previous_declared_key, previous_expiration_time = self.declared_group_key, self.declared_expiration_time
+                self.declared_group_key, self.declared_expiration_time = None, float('inf')
+                self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
+                await self.dht.declare_averager(previous_declared_key, self.endpoint, previous_expiration_time,
+                                                looking_for_group=False, return_future=True)
+
+
+def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
+    """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
+    schema_dicts = [{field_name: str(field_value)
+                    for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
+                    for tensor in tensors]
+    return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
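
compute_schema_hash depends only on tensor metadata (shape, dtype, device, layout), not on the values, so peers with identically structured parameters end up with matching schema hashes. A small illustration (hypothetical usage, not part of the commit; it assumes DHTID.generate(source=...) is deterministic for a fixed source, as the code above relies on):

    import torch
    from hivemind.client.averaging.matchmaking import compute_schema_hash

    a = compute_schema_hash([torch.zeros(10), torch.randn(3, 3)])
    b = compute_schema_hash([torch.ones(10), torch.randn(3, 3) * 5])   # same shapes/dtypes, different values
    c = compute_schema_hash([torch.zeros(10, dtype=torch.float64)])    # different schema
    assert a == b and a != c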

+ 75 - 0
hivemind/dht/__init__.py

@@ -22,6 +22,7 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
 
 import uvloop
+from numpy import nextafter
 
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
@@ -37,16 +38,25 @@ FLAT_EXPERT = -1     # grid prefix reserved for storing 1d expert uids. Used to
 UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
 PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
 #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
+GroupKey = str
+GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]+$')  # e.g. bert_exp4_averaging.0b01001101
 
 
 def is_valid_uid(maybe_uid: str) -> bool:
+    """ A uid must contain a string expert type, followed by one or more .-separated numeric indices """
     return bool(UID_PATTERN.fullmatch(maybe_uid))
 
 
 def is_valid_prefix(maybe_prefix: str) -> bool:
+    """ A uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
     return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
 
 
+def is_valid_group(maybe_group: str) -> bool:
+    """ A group key must contain a string group type, followed by a period and a string of binary group bits """
+    return bool(GROUP_PATTERN.fullmatch(maybe_group))
+
+
 def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
     """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
@@ -118,6 +128,7 @@ class DHT(mp.Process):
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
                  receiver_threads: int = 1, negative_caching: bool = True, expiration: float = 300, **kwargs):
         super().__init__()
+        assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.receiver_threads, self.max_workers, self.parallel_rpc = receiver_threads, max_workers, parallel_rpc
         self.expiration, self.negative_caching = expiration, negative_caching
@@ -457,3 +468,67 @@ class DHT(mp.Process):
         if future is not None:
             future.set_result(best_experts_batch)
         return best_experts_batch
+
+    def declare_averager(self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float, *,
+                         looking_for_group: bool = True, return_future: bool = False) -> Union[bool, MPFuture]:
+        """
+        Add (or remove) the averager to a given allreduce bucket
+
+        :param group_key: allreduce group key, e.g. my_averager.0b011011101
+        :param endpoint: averager public endpoint for incoming requests
+        :param expiration_time: intent to run allreduce before this timestamp
+        :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
+          If False, this will instead mark the averager as no longer looking for a group (e.g. it has already finished)
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :return: True if declared, False if declaration was rejected by DHT peers
+        :note: when leaving (i.e. looking_for_group=False), please specify the same expiration_time as when entering the group
+        :note: setting looking_for_group=False does *not* guarantee that others will immediately stop querying you.
+        """
+        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_declare_averager', [],
+                        dict(group_key=group_key, endpoint=endpoint, expiration_time=expiration_time,
+                             looking_for_group=looking_for_group, future=_future)))
+        return future if return_future else future.result()
+
+    async def _declare_averager(self, node: DHTNode, *, group_key: str, endpoint: Endpoint,
+                                expiration_time: DHTExpiration, looking_for_group: bool, future: MPFuture):
+        try:
+            expiration_time = expiration_time if looking_for_group else nextafter(expiration_time, float('inf'))
+            # ^-- when declaring averager inactive, we increment expiration time to overwrite the pre-existing entry
+            store_ok = await node.store(
+                key=group_key, subkey=endpoint, value=looking_for_group, expiration_time=expiration_time)
+            future.set_result(store_ok)
+        except Exception as e:
+            future.set_exception(e)
+
+    def get_averagers(self, group_key: GroupKey, *, only_active: bool = True, return_future: bool = False
+                      ) -> Union[List[Tuple[Endpoint, DHTExpiration]], MPFuture]:
+        """
+        Find and return averagers in a specified all-reduce bucket
+
+        :param group_key: finds averagers that have this group key, e.g. my_averager.0b011011101
+        :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
+            if False, return all averagers under a given group_key regardless of value
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :return: endpoints and expirations of every matching averager
+        """
+        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_averagers', [], dict(group_key=group_key, only_active=only_active, future=_future)))
+        return future if return_future else future.result()
+
+    async def _get_averagers(self, node: DHTNode, *, group_key: str, only_active: bool, future: MPFuture):
+        try:
+            result = await node.get(group_key, latest=True)
+            if result is None:
+                logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
+                future.set_result([])
+                return
+            assert isinstance(result.value, dict), f"expected {group_key} to be a Dict[Endpoint, is_active], " \
+                                                   f"but got {result.value} of type {type(result.value)}."
+            averagers = [(endpoint, entry.expiration_time) for endpoint, entry in result.value.items()
+                         if not only_active or entry.value is True]
+            future.set_result(averagers)
+        except Exception as e:
+            future.set_exception(e)
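
Together, declare_averager and get_averagers maintain a per-group registry of peers in the DHT. A minimal usage sketch (mirrors test_getset_averagers below; the group key must match GROUP_PATTERN, and the endpoint here is a made-up example):

    import hivemind

    dht = hivemind.DHT(start=True)
    now = hivemind.get_dht_time()
    dht.declare_averager('mygroup.0b0110', '127.0.0.1:1337', expiration_time=now + 60)
    print(dht.get_averagers('mygroup.0b0110', only_active=True))   # -> [('127.0.0.1:1337', <expiration>)]

    # when leaving the group, re-declare with the same expiration_time and looking_for_group=False
    dht.declare_averager('mygroup.0b0110', '127.0.0.1:1337', expiration_time=now + 60, looking_for_group=False)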

+ 23 - 20
hivemind/proto/averaging.proto

@@ -4,34 +4,37 @@ import "runtime.proto";
 
 
 // Runs alongside each trainer to perform gating function averaging every now and then. Read more: client/averaging.py
 service DecentralizedAveraging {
-  rpc rpc_group_allreduce(PeerInfo) returns (stream MessageFromLeader);  // assemble a group and run all-reduce
+  rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
   rpc rpc_aggregate_part(AveragingData) returns (AveragingData);  // send my local shard => get aggregated shard
 }
 
 
-message PeerInfo {
+enum MessageCode {
+  NO_CODE = 0;               // Default value that should not be used explicitly
+  REQUEST_JOIN = 1;          // "Dear maybe leader, will you have me in your group as a follower?"
+  ACCEPTED = 2;              // "I accept you in my group, you now commit to responding to me"
+  BEGIN_ALLREDUCE = 3;       // "We can begin allreduce now. These are your peers."
+  PART_FOR_AVERAGING = 4;    // "I am running allreduce with you, here's a part of my tensor that you should aggregate"
+  AVERAGED_PART = 5;         // "I aggregated your part with others and here's the average for that part"
+  NOT_DECLARED = 6;          // "I have not declared my group id yet, how the heck did you even find me? Go away."
+  NOT_A_LEADER = 7;          // "I am not a group leader. Go ask my leader instead."
+  BAD_EXPIRATION_TIME = 8;   // "I will not accept you. I cannot guarantee that we begin before you expire."
+  BAD_SCHEMA_HASH = 9;       // "I will not accept you. I am not averaging the same type of tensors as you."
+  BAD_GROUP_ID = 10;         // "I will not accept your request, your group id does not match with any groups i'm in."
+  DUPLICATE_ENDPOINT = 11;   // "I will not accept you, i already have exactly the same endpoint in my current group."
+  GROUP_IS_FULL = 12;        // "I will not accept you, my group already contains too many peers."
+  NOT_LOOKING_FOR_GROUP = 13;// "I'm not available at the moment. Please, get lost."
+  PROTOCOL_VIOLATION = 14;   // "You did something so unspeakable that i don't have a special code for that."
+  INTERNAL_ERROR = 15;       // "I messed up, we will have to stop allreduce because of that."
+  CANCELLED = 16;            // "[from peer during allreduce] I no longer want to participate in AllReduce."
+  GROUP_DISBANDED = 17;      // "[from leader] The group is closed. Go find another group."
+}
+
+message JoinRequest {
   string endpoint = 1;          // A follower accepts incoming allreduce requests at this address
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
 }
 
 
-enum MessageCode {
-  // response to join request
-  ACCEPTED = 0;              // "I accept you in my group, you will not commit to responding to me."
-  NOT_A_LEADER = 1;          // "I am not a group a leader. Go ask my leader instead."
-  ALREADY_RUNNING = 2;       // "My group has already began merging. Here's the group leader."
-  NOT_LOOKING_FOR_GROUP = 3; // "I'm not available at the moment. Please, get lost."
-  BAD_EXPIRATION_TIME = 4;   // "I will not accept you. I cannot guarantee that we begin before you expire."
-  BAD_SCHEMA_HASH = 5;       // "I will not accept you. I am not averaging the samy type of tensors as you."
-  DUPLICATE_ENDPOINT = 6;    // "I will not accept you, i already have exactly the same endpoint in my current group"
-  GROUP_IS_FULL = 7;         // "I will not accept you, my group already contains too many peers"
-  BEGIN_ALLREDUCE = 8;       // "We can begin allreduce now. These are your peers."
-  GROUP_DISBANDED = 9;       // "The group is closed. Go find another group."
-  UNKNOWN_GROUP_ID = 10;     // "Your request uses a group id that doesn't match with any group i know"
-  PROTOCOL_VIOLATION = 11;   // "One of peers did something in violation of the allreduce protocol"
-  INTERNAL_ERROR = 12;       // "We encountered an unexpected error on our side"
-  CANCELLED = 13;            // "A peer cancelled allreduce while averaging"
-}
-
 message MessageFromLeader {
 message MessageFromLeader {
   MessageCode code = 1;
   bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
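
The generated classes for these messages (hivemind.proto.averaging_pb2) are what matchmaking.py above exchanges over gRPC. A rough sketch of the join handshake from a follower's point of view (illustrative only; the field values are made up, and suggested_leader is a MessageFromLeader field used above but not shown in this excerpt):

    from hivemind.proto import averaging_pb2

    request = averaging_pb2.JoinRequest(endpoint='127.0.0.1:1337', schema_hash=b'\x00' * 20, expiration=1600000000.0)
    reply = averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
    if reply.code == averaging_pb2.GROUP_DISBANDED and reply.suggested_leader:
        retry_with = reply.suggested_leader   # as request_join_group() does above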

+ 1 - 0
hivemind/utils/timed_storage.py

@@ -8,6 +8,7 @@ from typing import TypeVar, NamedTuple, Generic, Optional, Dict, List, Iterator,
 KeyType = TypeVar('KeyType')
 ValueType = TypeVar('ValueType')
 get_dht_time = time.time  # a global (weakly synchronized) time
+MAX_DHT_TIME_DISCREPANCY_SECONDS = 3  # max allowed difference between get_dht_time for two DHT nodes. Enforced when joining DHT.(TODO)
 DHTExpiration = float
 ROOT = 0
 
 

+ 57 - 55
tests/test_averaging.py

@@ -1,46 +1,62 @@
 import asyncio
 import random
 import time
-from itertools import product
 
 import torch
 import pytest
 import hivemind
-from hivemind.client.allreduce import GroupAllReduce, split_into_parts, restore_from_parts
-from hivemind.utils import LOCALHOST
+from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
+from hivemind.utils import Endpoint
 
 
 @pytest.mark.forked
-@pytest.mark.asyncio
-async def test_allreduce_direct():
-    # WARNING! this test uses an early interface that will change by the time DecentralizedAverager is finished
+def test_getset_averagers():
+    dht = hivemind.DHT(start=True)
+
+    t = hivemind.get_dht_time()
+    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 60)
+    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', expiration_time=t + 61)
+
+    q1 = dht.get_averagers('bucket.0b10110', only_active=True)
+
+    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 66)
+    q2 = dht.get_averagers('bucket.0b10110', only_active=True)
+
+    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', looking_for_group=False,
+                         expiration_time=t + 61)
+    q3 = dht.get_averagers('bucket.0b10110', only_active=True)
+    q4 = dht.get_averagers('bucket.0b10110', only_active=False)
+
+    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
+    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
+    assert len(q3) == 1 and ('localhvost', t + 66) in q3
+    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
+
 
 
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_allreduce_once():
     dht = hivemind.DHT(start=True)
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
+    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
 
 
-    reference = [(tensors1[i] + tensors2[i] + tensors3[i]) / 3 for i in range(len(tensors1))]
-
-    averager1 = hivemind.DecentralizedAverager(tensors1, dht=dht, start=True, max_size=3, timeout=5)
-    averager2 = hivemind.DecentralizedAverager(tensors2, dht=dht, start=True, max_size=3, timeout=5)
-    averager3 = hivemind.DecentralizedAverager(tensors3, dht=dht, start=True, max_size=3, timeout=5)
-
-    future1 = averager1.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager1.port}",
-                                        leader_endpoint=None, return_future=True)
-    time.sleep(0.1)
+    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
 
 
-    future2 = averager2.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager2.port}",
-                                        leader_endpoint=f"{LOCALHOST}:{averager1.port}",
-                                        return_future=True)
+    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
+                                                prefix='mygroup', initial_group_bits='0110', listen_on='127.0.0.1:*',
+                                                start=True)
+                 for tensors in [tensors1, tensors2, tensors3, tensors4]]
 
 
-    future3 = averager3.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager3.port}",
-                                        leader_endpoint=f"{LOCALHOST}:{averager1.port}",
-                                        return_future=True)
+    futures = []
+    for averager in averagers:
+        futures.append(averager.step(return_future=True))  # TODO revert to hard version
+        time.sleep(0.5)
 
 
-    for future in future1, future2, future3:
-        for ref, our in zip(reference, await future):
+    for future in futures:
+        for ref, our in zip(reference, future.result()):
             assert torch.allclose(ref, our)
 
 
 
 
@@ -49,50 +65,36 @@ async def test_allreduce_direct():
 async def test_allreduce_protocol():
     """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
     peers = "alice", "bob", "carol"
-    expiration_offsets = 4, 0, 1
 
     tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
                        for i, peer in enumerate(peers)}
 
 
-    alice, bob, carol = allreduce_protocols = [
-        GroupAllReduce(endpoint=peer, expiration=hivemind.get_dht_time() + offset, tensors=tensors_by_peer[peer])
-        for peer, offset in zip(peers, expiration_offsets)]
-
-    bob.start_new_group()
-    bob.add_peer_to_group(alice.info.endpoint)
-    alice.join_group(bob, bob.group_id)
-    bob.add_peer_to_group(carol.info.endpoint)
-    carol.join_group(carol, bob.group_id)
-
-    bob.leader_begin_allreduce()
-    ordered_group_endpoints = await bob.assembled_group
-    assert len(ordered_group_endpoints) == len(peers)
-
-    carol.follower_begin_allreduce(ordered_group_endpoints)
-    alice.follower_begin_allreduce(ordered_group_endpoints)
-
-    chunks_by_peer = {protocol.info.endpoint: {
-        peer: part for peer, part in zip(peers, split_into_parts(protocol.local_tensors, len(ordered_group_endpoints)))
-    } for protocol in allreduce_protocols}
+    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
+    allreduce_protocols = [AllReduceProtocol(
+        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer], ordered_group_endpoints=peers)
+        for peer in peers]
 
 
-    all_pairs = list(product(allreduce_protocols, peers))
-    random.shuffle(all_pairs)
-    await asyncio.gather(*(
-        peer_allreduce.accumulate(source_peer, chunks_by_peer[source_peer][peer_allreduce.info.endpoint])
-        for peer_allreduce, source_peer in all_pairs))
+    async def _accumulate(sender: Endpoint, recipient: Endpoint):
+        sender_allreduce = allreduce_protocols[peers.index(sender)]
+        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
+        averaged_part = await recipient_allreduce.accumulate_part(
+            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
+        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)
 
 
-    averaged_parts = await asyncio.gather(*(protocol.averaged_part for protocol in allreduce_protocols))
-    tensor_shapes = [tensor.shape for tensor in alice.local_tensors]
-    averaged_tensors = restore_from_parts(averaged_parts, tensor_shapes)
+    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
+                        if sender != recipient})
 
 
     reference_tensors = [
         sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
         for i in range(len(tensors_by_peer[peers[0]]))
     ]
 
 
-    assert len(averaged_tensors) == len(reference_tensors)
-    assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
-               for our, ref in zip(averaged_tensors, reference_tensors))
+    for peer, allreduce in zip(peers, allreduce_protocols):
+        assert allreduce.averaged_tensors.done()
+        averaged_tensors = await allreduce
+        assert len(averaged_tensors) == len(reference_tensors)
+        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
+                   for our, ref in zip(averaged_tensors, reference_tensors))
 
 
 
 
 @pytest.mark.forked