
Averager: update group keys after every step, infer nbits dynamically (#141)

justheuristic 4 years ago
parent
commit
10917b259e

+ 6 - 1
docs/modules/client.rst

@@ -16,4 +16,9 @@
 
 .. autoclass:: RemoteMixtureOfExperts
    :members:
-   :member-order: bysource
+   :member-order: bysource
+
+.. autoclass:: DecentralizedAverager
+   :members:
+   :member-order: bysource
+   :exclude-members: get_tensors, update_tensors, rpc_join_group, rpc_aggregate_part

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.26'
+__version__ = '0.8.27'

+ 20 - 22
hivemind/client/averaging/__init__.py

@@ -6,7 +6,6 @@ import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
-import random
 from concurrent.futures.thread import ThreadPoolExecutor
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
@@ -16,7 +15,7 @@ import numpy as np
 
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
-from hivemind.client.averaging.matchmaking import Matchmaking
+from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
 from hivemind.utils import get_logger, Endpoint, Port, MPFuture, GRPC_KEEPALIVE_OPTIONS, get_dht_time, MSGPackSerializer
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
@@ -24,7 +23,6 @@ from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
 
-INITIAL_GROUP_NBITS = 3
 DataForGather = Any
 logger = get_logger(__name__)
 
@@ -43,8 +41,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
-    :param initial_group_bits: a string of bits ('0' and '1') that define initial group key (bucket index)
-      by default, sample a random bit sequence of length {INITIAL_GROUP_NBITS}
+    :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
     :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
     :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
@@ -56,7 +53,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
     :param throughput: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
-          If throughput == 0, averager will run in client-only mode (TODO not implemented yet!)
+          If throughput == 0, averager will rely on its groupmates to do all the averaging.
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
            if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
@@ -64,9 +61,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
-    You can perform averaging using DecentralizedOptimizer (see below) or by manually running each step as such:
 
-    >> TODO add a working example here
+    Example:
+
+    >>> averager = DecentralizedAverager(...)
+    >>> with averager.get_tensors() as tensors:
+    >>>     # run some code, modify tensors if necessary
+    >>>     tensors[0] += 1
+    >>> # do not use tensors after the lock is released
+    >>> metadata = averager.step(gather=dict(my_batch_size=32))
+    >>> # run averaging once (in-place), gather metadata from groupmates
+    >>> with averager.get_tensors() as tensors_after_averaging:
+    >>>     pass # use the averaged tensors
     """
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
@@ -81,16 +87,13 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
-        assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), "throughput must be a" \
-                                                                                                " nonnegative float32"
+        assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
+            "throughput must be a non-negative float32"
         if not listen:
             raise NotImplementedError("Client-only averaging is not implemented yet.")
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
-        if initial_group_bits is None:
-            initial_group_bits = ''.join(random.choices('01', k=INITIAL_GROUP_NBITS))
-            logger.debug(f"Initializing with random {INITIAL_GROUP_NBITS}-bit group index: {initial_group_bits}")
-        assert len(initial_group_bits) >= INITIAL_GROUP_NBITS and all(bit in '01' for bit in initial_group_bits)
+        assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
 
         super().__init__()
         self.dht = dht
@@ -178,7 +181,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         else:
             logger.warning("DHT shutdown has no effect: the process is not alive")
 
-    def step(self, allow_retries: bool = True, gather: Optional[DataForGather] = None, timeout: Optional[float] = None,
+    def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
             wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
@@ -219,7 +222,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                 future.set_result(gathered_data_by_peer)
 
-            except AllreduceException:
+            except (AllreduceException, MatchmakingException):
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
                     future.set_result(None)
@@ -248,12 +251,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """
         A contextmanager that gives user access to averaged tensors.
         It is guaranteed that the averager will not modify tensors while this context is active.
-
-        Example:
-              >>> with averager.get_tensors() as tensors:
-              >>>     update_model(tensors)
-              >>>     tensors[0] += 1
-              >>> # do not use tensors after the lock is acquired
+        Please do not modify the yielded tensors in-place after the context is released.
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
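For reference, a minimal usage sketch of the reworked step()/get_tensors() API described in the docstring above. The constructor arguments, DHT address and tensor shape below are illustrative assumptions, not taken from this commit:

import torch
import hivemind
from hivemind.client.averaging import DecentralizedAverager

# assumed setup: one local DHT node and a single 16-element tensor to average
dht = hivemind.DHT(listen_on='127.0.0.1:*', start=True)
averager = DecentralizedAverager([torch.zeros(16)], dht, start=True,
                                 prefix='my_averager', target_group_size=4)

with averager.get_tensors() as tensors:
    tensors[0] += 1  # modify local tensors while holding the lock

# one averaging round (in-place); returns gathered metadata per peer, or None on failure
metadata = averager.step(gather=dict(my_batch_size=32))

with averager.get_tensors() as averaged:
    print(averaged[0])  # read the averaged result under the lock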

+ 2 - 1
hivemind/client/averaging/allreduce.py

@@ -123,12 +123,13 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
 
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
                  ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
-                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], gathered: Sequence[Any] = (),
+                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], group_key_seed: int, gathered: Sequence[Any] = (),
                  return_deltas: bool = False):
         super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
                          ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
         self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
         self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future()
+        self.group_key_seed = group_key_seed
 
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
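A hedged sketch of constructing the extended AllReduceRunner with its new group_key_seed argument. The group id, endpoints, sizes and seed are placeholders, and CompressionType.NONE is assumed to exist in runtime_pb2; this is not how the commit itself builds the runner:

import asyncio
import torch
from hivemind.proto import runtime_pb2
from hivemind.client.averaging.allreduce import AllReduceRunner

async def build_runner():
    # constructed inside a running event loop, the way the averager does it
    return AllReduceRunner(
        group_id=b'sample-group-id', tensors=[torch.zeros(16)], endpoint='127.0.0.1:1337',
        ordered_group_endpoints=['127.0.0.1:1337', '127.0.0.1:1338'],
        compression_type=runtime_pb2.CompressionType.NONE,
        chunk_size_bytes=2 ** 16, part_sizes=(10, 6),
        group_key_seed=42,  # new in this commit: shared seed for deterministic key updates
        gathered=[b'', b''])

runner = asyncio.run(build_runner())
print(runner.group_key_seed)  # 42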

+ 142 - 0
hivemind/client/averaging/key_manager.py

@@ -0,0 +1,142 @@
+import asyncio
+import re
+import random
+from typing import Optional, List, Tuple
+
+import numpy as np
+
+from hivemind.dht import DHT
+from hivemind.client.averaging.allreduce import AllReduceRunner
+from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
+
+GroupKey = str
+GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]*$')  # e.g. bert_exp4_averaging.0b01001101
+logger = get_logger(__name__)
+
+
+def is_valid_group(maybe_group: str) -> bool:
+    """ A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata"""
+    return bool(GROUP_PATTERN.fullmatch(maybe_group))
+
+
+class GroupKeyManager:
+    """
+    Utility class that declares and fetches averaging-related keys using a DHT
+    """
+    RESERVED_KEY_FOR_NBITS = '::NBITS'
+
+    def __init__(self, dht: DHT, endpoint: Endpoint, prefix: str, initial_group_bits: Optional[str],
+                 target_group_size: int, insufficient_size: Optional[int] = None, excessive_size: Optional[int] = None,
+                 nbits_expiration: float = 60):
+        assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        if initial_group_bits is None:
+            search_result = dht.get(f"{prefix}.0b", latest=True)
+            initial_group_bits = self.get_suggested_nbits(search_result) or ''
+        self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
+        self.target_group_size = target_group_size
+        self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
+        self.excessive_size = excessive_size or target_group_size * 3
+        self.nbits_expiration = nbits_expiration
+        self.suggested_nbits: Optional[int] = None
+
+    @property
+    def current_key(self) -> GroupKey:
+        return f"{self.prefix}.0b{self.group_bits}"
+
+    async def declare_averager(self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float,
+                               looking_for_group: bool = True) -> bool:
+        """
+        Add (or remove) the averager to a given allreduce bucket
+
+        :param group_key: allreduce group key, e.g. my_averager.0b011011101
+        :param endpoint: averager public endpoint for incoming requests
+        :param expiration_time: intent to run allreduce before this timestamp
+        :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
+          If False, this will instead mark the averager as no longer looking for a group (e.g. it already finished)
+        :return: True if declared, False if declaration was rejected by DHT peers
+        :note: when leaving (i.e. is_active=False), please specify the same expiration_time as when entering the group
+        :note: setting is_active=False does *not* guarantee that others will immediately stop querying you.
+        """
+        expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float('inf')))
+        return await self.dht.store(key=group_key, subkey=endpoint, value=looking_for_group,
+                                    expiration_time=expiration_time, return_future=True)
+
+    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
+        """
+        Find and return averagers that were declared with a given all-reduce key
+
+        :param group_key: finds averagers that have this group key, e.g. my_averager.0b011011101
+        :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
+            if False, return all averagers under a given group_key regardless of value
+        :return: endpoints and expirations of every matching averager
+        """
+        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
+        result = await self.dht.get(group_key, latest=True, return_future=True)
+        if result is None or not isinstance(result.value, dict):
+            logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
+            return []
+        averagers = [(key, entry.expiration_time) for key, entry in result.value.items()
+                     if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)]
+        num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
+
+        suggested_nbits = self.get_suggested_nbits(result)
+        if suggested_nbits is not None and suggested_nbits != self.suggested_nbits:
+            self.suggested_nbits = suggested_nbits
+            logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
+        elif num_active_averagers >= self.excessive_size:
+            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
+            logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+        return averagers
+
+    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
+        """ notify other peers that they can run averaging at this depth """
+        return await self.dht.store(key=group_key, subkey=self.RESERVED_KEY_FOR_NBITS, value=nbits,
+                                    expiration_time=expiration_time, return_future=True)
+
+    @classmethod
+    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
+        if isinstance(search_result, ValueWithExpiration) and cls.RESERVED_KEY_FOR_NBITS in search_result.value \
+                and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int):
+            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
+        else:
+            return None
+
+    async def update_key_on_group_assembled(self, allreduce_group: AllReduceRunner, is_leader: bool = True):
+        """ this function is triggered every time an averager finds an allreduce group """
+        rng = random.Random(allreduce_group.group_key_seed)
+        index = allreduce_group.ordered_group_endpoints.index(self.endpoint)
+        generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index]
+        nbits = int(np.ceil(np.log2(self.target_group_size)))
+        new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
+        self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):]
+        logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
+
+        if is_leader and self.insufficient_size < allreduce_group.group_size < self.excessive_size:
+            asyncio.create_task(self.notify_stragglers_on_success())
+        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
+            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
+            self.group_bits = ''.join((random.choice('01') for _ in range(num_extra_bits))) + self.group_bits
+            self.group_bits = self.group_bits[-self.suggested_nbits:]
+        self.suggested_nbits = None
+
+    async def update_key_on_not_enough_peers(self):
+        """ this function is triggered whenever averager fails to assemble group within timeout """
+        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
+        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:]
+        if self.group_bits != prev_nbits:
+            logger.warning(f'{self.endpoint} - switching to {len(self.group_bits)}-bit keys')
+        self.suggested_nbits = None
+
+    async def notify_stragglers_on_success(self):
+        """ Find averagers that have fewer nbits and redirect them to your current nbits """
+        for nbits in reversed(range(1, len(self.group_bits) - 1)):
+            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
+            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
+
+            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
+                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
+                break
+
+        root_data = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True)
+        if root_data is None or self.RESERVED_KEY_FOR_NBITS not in root_data.value:
+            await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
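To make the key-update rule in update_key_on_group_assembled() above concrete, here is a standalone sketch of the same arithmetic; the seed, peer index and bit strings are made-up values for illustration only:

import random
import numpy as np

target_group_size = 4     # configured maximum group size
group_size = 4            # peers actually assembled
group_key_seed = 12345    # shared seed broadcast by the leader
index = 2                 # this peer's position in ordered_group_endpoints
old_group_bits = '01'     # this peer's current bucket suffix

# every peer seeds the same RNG, so all of them agree on one permutation
rng = random.Random(group_key_seed)
generalized_index = rng.sample(range(target_group_size), group_size)[index]

nbits = int(np.ceil(np.log2(target_group_size)))         # 2 bits for groups of up to 4
new_bits = bin(generalized_index)[2:].rjust(nbits, '0')   # e.g. '10'

# append the freshly drawn bits, then keep only the last len(old_group_bits) characters
new_group_bits = (old_group_bits + new_bits)[-len(old_group_bits):]
print(new_group_bits)  # the peer hops to another bucket at the same depth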

+ 61 - 54
hivemind/client/averaging/matchmaking.py

@@ -12,14 +12,15 @@ import asyncio
 import grpc
 import torch
 
-import hivemind
 from hivemind.client.averaging.allreduce import AllReduceRunner
 from hivemind.client.averaging.load_balancing import load_balance_peers
-from hivemind.dht import DHTID, DHTExpiration, get_dht_time, GroupKey
-from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, TimedStorage
+from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
+from hivemind.dht import DHT, DHTID, DHTExpiration, get_dht_time
+from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, timed_storage, TimedStorage
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.utils.grpc import ChannelCache
 
+
 logger = get_logger(__name__)
 
 
@@ -36,7 +37,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
       Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
     """
 
-    def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
+    def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *,
                 prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
                 averaging_expiration: float = 15, request_timeout: float, throughput: Optional[float] = None,
                 min_vector_size: int, **allreduce_kwargs):
@@ -46,12 +47,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                           "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring.")
 
         super().__init__()
-        self.dht, self.endpoint, self.averaged_tensors = dht, endpoint, tuple(averaged_tensors)
-        self.prefix, self.group_bits = prefix, initial_group_bits
+        self.endpoint, self.averaged_tensors = endpoint, tuple(averaged_tensors)
+        self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
-        self.throughput, self.min_vector_size = throughput, min_vector_size
-        self.allreduce_kwargs = allreduce_kwargs
+        self.throughput, self.min_vector_size, self.allreduce_kwargs = throughput, min_vector_size, allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
         self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)
 
@@ -63,17 +63,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration, target_group_size)
-        self.data_for_gather: bytes = None
+        self.potential_leaders = PotentialLeaders(endpoint, averaging_expiration, target_group_size)
+        self.data_for_gather: Optional[bytes] = None
 
    @property
    def is_looking_for_group(self):
        return self.lock_looking_for_group.locked()
 
-    @property
-    def current_group_key(self) -> GroupKey:
-        return f"{self.prefix}.0b{self.group_bits}"
-
    def __repr__(self):
        lfg_status = "looking for group," if self.is_looking_for_group else "not looking for group,"
        if self.is_looking_for_group:
@@ -83,12 +79,12 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                lfg_status += f" leading {len(self.current_followers)} followers,"
        schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
        return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
-               f" current key = {self.current_group_key})"
+               f" current key = {self.group_key_manager.current_key})"
 
    async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
                             ) -> Optional[AllReduceRunner]:
        """
-        :param gather: optionally send this data to all peers in the next group and gather it from every groupmate
+        :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates
        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
        :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
        Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
@@ -127,10 +123,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
        """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
-        async with self.potential_leaders.begin_search(self.current_group_key, timeout):
-            # TODO update group_bits on success! reduce number of bits on not enough peers.
-            # TODO after allreduce finishes, we may need to ask leader to notify lower keys about this
-            # (so as to fix possible network partitioning if some peers operate on a much smaller nbits)
+        async with self.potential_leaders.begin_search(self.group_key_manager, timeout):
            while True:
                try:
                    next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
@@ -148,7 +141,6 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                            return await self.leader_assemble_group()
                        elif len(self.current_followers) > 0:
                            await self.leader_disband_group()
-                            # TODO maybe adjust grid size
                        continue
                except Exception as e:
                    if not self.assembled_group.done():
@@ -262,8 +254,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
            allreduce_group = self.assembled_group.result()
            yield averaging_pb2.MessageFromLeader(
                code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
-                ordered_group_endpoints=allreduce_group.ordered_group_endpoints,
-                part_sizes=allreduce_group.part_sizes, gathered=allreduce_group.gathered)
+                ordered_group_endpoints=allreduce_group.ordered_group_endpoints, part_sizes=allreduce_group.part_sizes,
+                gathered=allreduce_group.gathered, group_key_seed=allreduce_group.group_key_seed)
 
        except Exception as e:
            logger.exception(e)
@@ -319,11 +311,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                gathered.append(follower_info.gather if follower_info.gather else None)
 
        part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
+        group_key_seed = random.randint(- 2 ** 31, 2 ** 31 - 1)
 
        logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
                                          ordered_group_endpoints=ordered_group_endpoints, part_sizes=part_sizes,
-                                          gathered=gathered, **self.allreduce_kwargs)
+                                          gathered=gathered, group_key_seed=group_key_seed, **self.allreduce_kwargs)
+        await self.group_key_manager.update_key_on_group_assembled(allreduce_group, is_leader=True)
        self.assembled_group.set_result(allreduce_group)
        return allreduce_group
 
@@ -340,7 +334,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
        logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
                                          ordered_group_endpoints=tuple(ordered_group_endpoints),
-                                          part_sizes=tuple(part_sizes), gathered=msg.gathered, **self.allreduce_kwargs)
+                                          part_sizes=tuple(part_sizes), gathered=msg.gathered,
+                                          group_key_seed=int(msg.group_key_seed), **self.allreduce_kwargs)
+        await self.group_key_manager.update_key_on_group_assembled(allreduce_group)
        self.assembled_group.set_result(allreduce_group)
        return allreduce_group
 
@@ -353,9 +349,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 class PotentialLeaders:
    """ An utility class that searches for averagers that could become our leaders """
 
-    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration,
-                 target_group_size: Optional[int]):
-        self.endpoint, self.dht, self.averaging_expiration = endpoint, dht, averaging_expiration
+    def __init__(self, endpoint: Endpoint, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
+        self.endpoint, self.averaging_expiration = endpoint, averaging_expiration
        self.target_group_size = target_group_size
        self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
        self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
@@ -367,12 +362,12 @@ class PotentialLeaders:
        self.search_end_time = float('inf')
 
    @contextlib.asynccontextmanager
-    async def begin_search(self, group_key: GroupKey, timeout: Optional[float]):
+    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float]):
        async with self.lock_search:
            self.running.set()
            self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
-            update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
-            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
+            update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
+            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
            try:
                yield self
            finally:
@@ -429,38 +424,46 @@ class PotentialLeaders:
        else:
            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
 
-    async def _update_queue_periodically(self, group_key: GroupKey):
-        DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-        while get_dht_time() < self.search_end_time:
-            new_peers = await self.dht.get_averagers(group_key, only_active=True, return_future=True)
-            self.max_assured_time = max(self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY)
-
-            self.leader_queue.clear()
-            for peer, peer_expiration_time in new_peers:
-                if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
-                    continue
-                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
+        try:
+            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+            while get_dht_time() < self.search_end_time:
+                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+                self.max_assured_time = max(self.max_assured_time,
+                                            get_dht_time() + self.averaging_expiration - DISCREPANCY)
+
+                self.leader_queue.clear()
+                for peer, peer_expiration_time in new_peers:
+                    if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
+                        continue
+                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
 
-            self.update_finished.set()
+                self.update_finished.set()
 
-            await asyncio.wait(
-                {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
-                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
-            self.update_triggered.clear()
+                await asyncio.wait(
+                    {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
+                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
+                self.update_triggered.clear()
+        except Exception as e:
+            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            raise
 
-    async def _declare_averager_periodically(self, group_key: GroupKey):
+    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
        async with self.lock_declare:
            try:
                while True:
                    await self.running.wait()
 
                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-                    self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
+                    self.declared_group_key = group_key = key_manager.current_key
+                    self.declared_expiration_time = new_expiration_time
                    self.declared_expiration.set()
-                    await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
-                                                    looking_for_group=True, return_future=True)
+                    await key_manager.declare_averager(group_key, self.endpoint, expiration_time=new_expiration_time)
                    await asyncio.sleep(self.declared_expiration_time - get_dht_time())
+                    if self.running.is_set() and len(self.leader_queue) == 0:
+                        await key_manager.update_key_on_not_enough_peers()
+
            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
            finally:
@@ -468,8 +471,8 @@ class PotentialLeaders:
                    prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                    self.declared_group_key, self.declared_expiration_time = None, float('inf')
                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
-                    await self.dht.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
-                                                    looking_for_group=False, return_future=True)
+                    await key_manager.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
+                                                       looking_for_group=False)
 
 
 def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
@@ -478,3 +481,7 @@ def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
                     for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                    for tensor in tensors]
    return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
+
+
+class MatchmakingException(Exception):
+    """ An internal exception that marks undesired edge cases during averaging """

+ 51 - 77
hivemind/dht/__init__.py

@@ -12,6 +12,7 @@ The code is organized as follows:
 - [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric.
 - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
 """
+from __future__ import annotations
 import asyncio
 import ctypes
 import heapq
@@ -21,12 +22,11 @@ from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
 
-from numpy import nextafter
 
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
-from hivemind.dht.routing import get_dht_time, DHTValue
-from hivemind.utils import MPFuture, Endpoint, Hostname, get_logger, switch_to_uvloop, strip_port
+from hivemind.dht.routing import get_dht_time, DHTValue, DHTKey, Subkey
+from hivemind.utils import MPFuture, Endpoint, Hostname, get_logger, switch_to_uvloop, strip_port, ValueWithExpiration
 
 logger = get_logger(__name__)
 
@@ -37,8 +37,6 @@ FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to spe
 UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
 PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
 #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
-GroupKey = str
-GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]+$')  # e.g. bert_exp4_averaging.0b01001101
 
 
 def is_valid_uid(maybe_uid: str) -> bool:
@@ -50,12 +48,6 @@ def is_valid_prefix(maybe_prefix: str) -> bool:
     """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
     return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
 
-
-def is_valid_group(maybe_group: str) -> bool:
-    """ A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata"""
-    return bool(GROUP_PATTERN.fullmatch(maybe_group))
-
-
 def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
     """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
@@ -180,6 +172,54 @@ class DHT(mp.Process):
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None
 
+    def get(self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
+            ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
+        """
+        Search for a key across DHT and return either first or latest entry (if found).
+        :param key: same key as in node.store(...)
+        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
+        :returns: (value, expiration time); if value was not found, returns None
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
+        try:
+            result = await node.get(key, latest=latest, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+              subkey: Optional[Subkey] = None, return_future: bool = False, **kwargs) -> Union[bool, MPFuture]:
+        """
+        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
+        :note: store is a simplified interface to store_many, all kwargs are forwarded there
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: True if store succeeds, False if it fails (due to no response or newer value)
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
+                                           future=_future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+                     subkey: Optional[Subkey], future: MPFuture, **kwargs):
+        try:
+            result = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
     def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
         """
         Get this machine's visible address by requesting other peers or using pre-specified network addresses.
@@ -519,69 +559,3 @@ class DHT(mp.Process):
         if future is not None:
             future.set_result(best_experts_batch)
         return best_experts_batch
-
-    def declare_averager(self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float, *,
-                         looking_for_group: bool = True, return_future: bool = False) -> Union[bool, MPFuture]:
-        """
-        Add (or remove) the averager to a given allreduce bucket
-
-        :param group_key: allreduce group key, e.g. my_averager.0b011011101
-        :param endpoint: averager public endpoint for incoming requests
-        :param expiration_time: intent to run allreduce before this timestamp
-        :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
-          If False, this will instead mark that the averager as no longer looking for group, (e.g. it already finished)
-        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :return: True if declared, False if declaration was rejected by DHT peers
-        :note: when leaving (i.e. is_active=False), please specify the same expiration_time as when entering the group
-        :note: setting is_active=False does *not* guarantee that others will immediately stop to query you.
-        """
-        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_declare_averager', [],
-                        dict(group_key=group_key, endpoint=endpoint, expiration_time=expiration_time,
-                             looking_for_group=looking_for_group, future=_future)))
-        return future if return_future else future.result()
-
-    async def _declare_averager(self, node: DHTNode, *, group_key: str, endpoint: Endpoint,
-                                expiration_time: DHTExpiration, looking_for_group: bool, future: MPFuture):
-        try:
-            expiration_time = expiration_time if looking_for_group else float(nextafter(expiration_time, float('inf')))
-            # ^-- when declaring averager inactive, we increment expiration time to overwrite the pre-existing entry
-            store_ok = await node.store(
-                key=group_key, subkey=endpoint, value=looking_for_group, expiration_time=expiration_time)
-            future.set_result(store_ok)
-        except Exception as e:
-            if not future.done():
-                future.set_exception(e)
-
-    def get_averagers(self, group_key: GroupKey, *, only_active: bool = True, return_future: bool = False
-                      ) -> Union[List[Tuple[Endpoint, DHTExpiration]], MPFuture]:
-        """
-        Find and return averagers in a specified all-reduce bucket
-
-        :param group_key: finds averagers that have this group key, e.g. my_averager.0b011011101
-        :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
-            if False, return all averagers under a given group_key regardless of value
-        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :return: endpoints and expirations of every matching averager
-        """
-        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_averagers', [], dict(group_key=group_key, only_active=only_active, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_averagers(self, node: DHTNode, *, group_key: str, only_active: bool, future: MPFuture):
-        try:
-            result = await node.get(group_key, latest=True)
-            if result is None:
-                logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
-                future.set_result([])
-                return
-            assert isinstance(result.value, dict), f"expected {group_key} to be a Dict[Endpoint, is_active], " \
-                                                   f"but got {result.value} of type {type(result.value)}."
-            averagers = [(endpoint, entry.expiration_time) for endpoint, entry in result.value.items()
-                         if not only_active or entry.value is True]
-            future.set_result(averagers)
-        except Exception as e:
-            if not future.done():
-                future.set_exception(e)
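Note: the declare/get functionality removed from DHT here is not dropped from the codebase — the updated tests further down in this diff exercise the same calls through hivemind.client.averaging.key_manager.GroupKeyManager, which wraps a DHT instance and exposes them as coroutines. A minimal usage sketch based on those tests (the endpoint and prefix strings are placeholders, not real values):

    import asyncio
    import hivemind
    from hivemind.client.averaging.key_manager import GroupKeyManager

    async def declare_and_discover():
        dht = hivemind.DHT(start=True)
        # constructor arguments mirror the updated test_key_manager below
        key_manager = GroupKeyManager(dht, endpoint='127.0.0.1:1337', prefix='demo_averaging',
                                      initial_group_bits='10110', target_group_size=2)
        group_key = key_manager.current_key  # e.g. 'demo_averaging.0b10110'
        expiration = hivemind.get_dht_time() + 60
        # announce this averager as "looking for group" under its current key ...
        await key_manager.declare_averager(group_key, '127.0.0.1:1337', expiration_time=expiration)
        # ... then look up everyone else declared under the same key
        return await key_manager.get_averagers(group_key, only_active=True)

    peers = asyncio.run(declare_and_discover())  # [(endpoint, expiration_time), ...]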

+ 1 - 0
hivemind/proto/averaging.proto

@@ -44,6 +44,7 @@ message MessageFromLeader {
   repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
   repeated int32 part_sizes = 5;  // a sequence of tensor parts assigned to each peer, same order as endpoints
   repeated bytes gathered = 6;  // metadata (gather) from all groupmates in the same order as their endpoints
+  int32 group_key_seed = 7;  // a random seed used by peers to update their group keys
 }
 
 message AveragingData {
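The new group_key_seed field supports the headline change of this commit: after a group assembles, all of its members must update their group keys consistently, and a single leader-chosen seed (combined with the ordered endpoint list above) lets each peer derive its new bits deterministically without another round trip. The exact update rule lives in the key manager; the helper below only illustrates the idea, and its name and bit-derivation are assumptions rather than the shipped implementation:

    import random

    def illustrative_updated_bits(group_key_seed: int, ordered_group_endpoints: list, my_endpoint: str, nbits: int) -> str:
        # all groupmates share the seed and the ordered endpoint list, so each of them
        # can compute its own new suffix locally and still stay consistent with the others
        generator = random.Random(group_key_seed)
        shuffled = list(range(len(ordered_group_endpoints)))
        generator.shuffle(shuffled)
        my_index = shuffled[ordered_group_endpoints.index(my_endpoint)]
        return bin(my_index)[2:].rjust(nbits, '0')

    # both peers derive their (distinct) bits from the same seed, no extra messages needed:
    peers = ['1.2.3.4:1337', '5.6.7.8:1338']
    bits = [illustrative_updated_bits(42, peers, peer, nbits=1) for peer in peers]
    assert sorted(bits) == ['0', '1']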

+ 5 - 1
tests/benchmark_averaging.py

@@ -1,3 +1,4 @@
+import math
 import time
 import threading
 import argparse
@@ -31,6 +32,8 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
                         averaging_expiration: float, request_timeout: float, round_timeout: float,
                         hid_size: int, num_layers: int, spawn_dtime: float):
     dht_root = hivemind.DHT(listen_on=f'{LOCALHOST}:*', start=True)
+    num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))
+    nbits = int(round(math.log2(num_groups)))
     peer_tensors = [sample_tensors(hid_size, num_layers)
                     for _ in range(num_peers)]
     processes = {dht_root}
@@ -39,8 +42,9 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
         dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
                            initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
                            start=True)
+        initial_bits = bin(index % num_groups)[2:].rjust(nbits, '0')
         averager = hivemind.DecentralizedAverager(
-            peer_tensors[i], dht, prefix='my_tensor', initial_group_bits='0110', listen_on=f"{LOCALHOST}:*",
+            peer_tensors[i], dht, prefix='my_tensor', initial_group_bits=initial_bits, listen_on=f"{LOCALHOST}:*",
             compression_type=runtime_pb2.CompressionType.FLOAT16, target_group_size=target_group_size,
             averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
         processes.update({dht, averager})
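To make the new bucket arithmetic concrete: with, say, num_peers=16 and target_group_size=4, the benchmark now spreads peers over 2 ** round(log2(16 / 4)) = 4 buckets, so nbits = 2 and the averagers cycle through initial bits '00', '01', '10', '11' instead of all sharing the previously hard-coded '0110'. The same computation in isolation:

    import math

    num_peers, target_group_size = 16, 4
    num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))   # 4 buckets
    nbits = int(round(math.log2(num_groups)))                                # 2 bits per key
    initial_bits = [bin(index % num_groups)[2:].rjust(nbits, '0') for index in range(num_peers)]
    print(initial_bits[:5])  # ['00', '01', '10', '11', '00']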

+ 83 - 12
tests/test_averaging.py

@@ -7,31 +7,38 @@ import pytest
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
 from hivemind.client.averaging.load_balancing import load_balance_peers
+from hivemind.client.averaging.key_manager import GroupKeyManager
 from hivemind.utils import Endpoint
 
 
 @pytest.mark.forked
-def test_getset_averagers():
-    dht = hivemind.DHT(start=True)
+@pytest.mark.asyncio
+async def test_key_manager():
+    key_manager = GroupKeyManager(hivemind.DHT(start=True), endpoint='localhvost',
+                                  prefix='test_averaging', initial_group_bits='10110',
+                                  target_group_size=2)
 
     t = hivemind.get_dht_time()
-    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 60)
-    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', expiration_time=t + 61)
+    key = key_manager.current_key
+    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 60)
+    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61)
+
+    q1 = await key_manager.get_averagers(key, only_active=True)
 
-    q1 = dht.get_averagers('bucket.0b10110', only_active=True)
+    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 66)
+    q2 = await key_manager.get_averagers(key, only_active=True)
 
-    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 66)
-    q2 = dht.get_averagers('bucket.0b10110', only_active=True)
+    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61, looking_for_group=False)
+    q3 = await key_manager.get_averagers(key, only_active=True)
+    q4 = await key_manager.get_averagers(key, only_active=False)
 
-    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', looking_for_group=False,
-                         expiration_time=t + 61)
-    q3 = dht.get_averagers('bucket.0b10110', only_active=True)
-    q4 = dht.get_averagers('bucket.0b10110', only_active=False)
+    q5 = await key_manager.get_averagers('nonexistent_key.0b0101', only_active=False)
 
     assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
     assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
     assert len(q3) == 1 and ('localhvost', t + 66) in q3
     assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q2
+    assert len(q5) == 0
 
 
 @pytest.mark.forked
@@ -46,7 +53,7 @@ def test_allreduce_once():
     reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
 
     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', initial_group_bits='0110', listen_on='127.0.0.1:*',
+                                                prefix='mygroup', listen_on='127.0.0.1:*',
                                                 start=True)
                  for tensors in [tensors1, tensors2, tensors3, tensors4]]
 
@@ -64,6 +71,44 @@ def test_allreduce_once():
                 assert torch.allclose(ref, our, atol=1e-6)
 
 
+def compute_mean_std(averagers, unbiased=True):
+    results = []
+    for averager in averagers:
+        with averager.get_tensors() as tensors:
+            results.append([tensor.clone() for tensor in tensors])
+
+    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
+    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
+    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
+    return means, stds
+
+
+@pytest.mark.forked
+def test_allreduce_grid():
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    averagers = [hivemind.DecentralizedAverager(
+        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
+        prefix='mygroup', initial_group_bits=bin(i // 2)[2:].rjust(2, '0'), start=True)
+        for i in range(8)]
+
+    [means0], [stds0] = compute_mean_std(averagers)
+    assert not torch.allclose(stds0, torch.zeros_like(stds0))
+
+    prev_means, prev_stds = means0, stds0
+
+    for i in range(5):
+        step_futures = [averager.step(wait=False) for averager in averagers]
+        groups = [future.result() for future in step_futures]
+        [means], [stds] = compute_mean_std(averagers)
+        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
+        assert all(len(group) == 2 for group in groups)
+
+        if i <= 2:
+            assert torch.all(torch.le(stds, prev_stds))
+        else:
+            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
+
+
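The grid test above relies on two properties of all-reduce: each round replaces every participant's tensors with its group's mean, so the mean over all averagers stays exactly the same, while repeated rounds over different pairings keep mixing values until every peer holds nearly the same tensors, driving the per-coordinate std toward zero. A tiny pure-torch simulation of that behaviour (random pairings stand in for the real grid matchmaking):

    import torch

    values = torch.randn(8, 3)                    # 8 simulated peers, one 3-element tensor each
    global_mean = values.mean(dim=0)
    for _ in range(5):
        pairs = torch.randperm(8).reshape(4, 2)   # random 2-peer groups for this round
        for a, b in pairs.tolist():
            values[a] = values[b] = (values[a] + values[b]) / 2
        print(values.std(dim=0))                  # non-increasing, approaches zero
    assert torch.allclose(values.mean(dim=0), global_mean, atol=1e-5)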
 @pytest.mark.forked
 def test_allgather():
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
@@ -190,3 +235,29 @@ def test_load_balancing():
         assignment = load_balance_peers(vector_size, throughputs, min_size)
         assert np.sum(assignment) == vector_size
         assert np.min(assignment) >= 0
+
+
+@pytest.mark.forked
+def test_too_few_peers():
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    averagers = [hivemind.DecentralizedAverager(
+        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
+        averaging_expiration=1, request_timeout=0.5,
+        prefix='mygroup', initial_group_bits=bin(i)[2:].rjust(3, '0'), start=True)
+        for i in range(4)]
+    step_futures = [averager.step(wait=False) for averager in averagers]
+    for future in step_futures:
+        assert len(future.result()) == 2
+
+
+@pytest.mark.forked
+def test_overcrowded():
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    averagers = [hivemind.DecentralizedAverager(
+        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
+        averaging_expiration=1, request_timeout=0.5,
+        prefix='mygroup', initial_group_bits='', start=True)
+        for _ in range(32)]
+    for t in range(5):
+        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
+        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1