
Averager: update group keys after every step, infer nbits dynamically (#141)

justheuristic committed 4 years ago
commit 10917b259e

+ 6 - 1
docs/modules/client.rst

@@ -16,4 +16,9 @@
 
 .. autoclass:: RemoteMixtureOfExperts
    :members:
-   :member-order: bysource
+   :member-order: bysource
+
+.. autoclass:: DecentralizedAverager
+   :members:
+   :member-order: bysource
+   :exclude-members: get_tensors, update_tensors, rpc_join_group, rpc_aggregate_part

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.26'
+__version__ = '0.8.27'

+ 20 - 22
hivemind/client/averaging/__init__.py

@@ -6,7 +6,6 @@ import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
-import random
 from concurrent.futures.thread import ThreadPoolExecutor
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
@@ -16,7 +15,7 @@ import numpy as np
 
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
-from hivemind.client.averaging.matchmaking import Matchmaking
+from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
 from hivemind.utils import get_logger, Endpoint, Port, MPFuture, GRPC_KEEPALIVE_OPTIONS, get_dht_time, MSGPackSerializer
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
@@ -24,7 +23,6 @@ from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
 
-INITIAL_GROUP_NBITS = 3
 DataForGather = Any
 logger = get_logger(__name__)
 
@@ -43,8 +41,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
-    :param initial_group_bits: a string of bits ('0' and '1') that define initial group key (bucket index)
-      by default, sample a random bit sequence of length {INITIAL_GROUP_NBITS}
+    :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
     :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
     :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
@@ -56,7 +53,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
     :param throughput: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
-          If throughput == 0, averager will run in client-only mode (TODO not implemented yet!)
+          If throughput == 0, averager will rely on its groupmates to do all the averaging.
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
             if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
@@ -64,9 +61,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
-    You can perform averaging using DecentralizedOptimizer (see below) or by manually running each step as such:
 
-    >> TODO add a working example here
+    Example:
+
+    >>> averager = DecentralizedAverager(...)
+    >>> with averager.get_tensors() as tensors:
+    >>>     # run some code, modify tensors if necessary
+    >>>     tensors[0] += 1
+    >>> # do not use tensors after the lock is released
+    >>> metadata = averager.step(gather=dict(my_batch_size=32))
+    >>> # run averaging once (in-place), gather metadata from groupmates
+    >>> with averager.get_tensors() as tensors_after_averaging:
+    >>>     pass # use the averaged tensors
     """
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
@@ -81,16 +87,13 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
-        assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), "throughput must be a" \
-                                                                                                " nonnegative float32"
+        assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
+            "throughput must be a non-negative float32"
         if not listen:
             raise NotImplementedError("Client-only averaging is not implemented yet.")
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
-        if initial_group_bits is None:
-            initial_group_bits = ''.join(random.choices('01', k=INITIAL_GROUP_NBITS))
-            logger.debug(f"Initializing with random {INITIAL_GROUP_NBITS}-bit group index: {initial_group_bits}")
-        assert len(initial_group_bits) >= INITIAL_GROUP_NBITS and all(bit in '01' for bit in initial_group_bits)
+        assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
 
         super().__init__()
         self.dht = dht
@@ -178,7 +181,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         else:
             logger.warning("DHT shutdown has no effect: the process is not alive")
 
-    def step(self, allow_retries: bool = True, gather: Optional[DataForGather] = None, timeout: Optional[float] = None,
+    def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
              wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
@@ -219,7 +222,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                 future.set_result(gathered_data_by_peer)
 
-            except AllreduceException:
+            except (AllreduceException, MatchmakingException):
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
                     future.set_result(None)
@@ -248,12 +251,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """
         A contextmanager that gives user access to averaged tensors.
         It is guaranteed that the averager will not modify tensors while this context is active.
-
-        Example:
-              >>> with averager.get_tensors() as tensors:
-              >>>     update_model(tensors)
-              >>>     tensors[0] += 1
-              >>> # do not use tensors after the lock is acquired
+        Please do not modify the yielded tensors in-place after the context is released.
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
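
As a complement to the updated docstring above, here is a minimal end-to-end sketch of the new constructor and step() signature. It assumes a locally running DHT; the tensor shape, the prefix 'demo_averaging' and the gathered metadata are illustrative values, not part of this commit.

import torch
import hivemind

dht = hivemind.DHT(start=True)
# initial_group_bits can now be omitted entirely: the averager infers a suitable
# key depth from the DHT via the new GroupKeyManager (see key_manager.py below)
averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(16)], dht=dht, prefix='demo_averaging',
    target_group_size=2, start=True)

# gather is now the first parameter of step(); the call returns {endpoint: metadata}
# on success or None if every retry failed (Matchmaking errors are now retried too).
# With a single peer this call would simply time out; run the same snippet on two
# or more peers that share the DHT.
gathered = averager.step(gather=dict(batch_size=32), timeout=60)
if gathered is None:
    print("averaging round did not succeed within the timeout")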

+ 2 - 1
hivemind/client/averaging/allreduce.py

@@ -123,12 +123,13 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
 
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
                  ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
-                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], gathered: Sequence[Any] = (),
+                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], group_key_seed: int, gathered: Sequence[Any] = (),
                  return_deltas: bool = False):
         super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
                          ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
         self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
         self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future()
+        self.group_key_seed = group_key_seed
 
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)

+ 142 - 0
hivemind/client/averaging/key_manager.py

@@ -0,0 +1,142 @@
+import asyncio
+import re
+import random
+from typing import Optional, List, Tuple
+
+import numpy as np
+
+from hivemind.dht import DHT
+from hivemind.client.averaging.allreduce import AllReduceRunner
+from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
+
+GroupKey = str
+GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]*$')  # e.g. bert_exp4_averaging.0b01001101
+logger = get_logger(__name__)
+
+
+def is_valid_group(maybe_group: str) -> bool:
+    """ A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata"""
+    return bool(GROUP_PATTERN.fullmatch(maybe_group))
+
+
+class GroupKeyManager:
+    """
+    Utility class that declares and fetches averaging-related keys using a DHT
+    """
+    RESERVED_KEY_FOR_NBITS = '::NBITS'
+
+    def __init__(self, dht: DHT, endpoint: Endpoint, prefix: str, initial_group_bits: Optional[str],
+                 target_group_size: int, insufficient_size: Optional[int] = None, excessive_size: Optional[int] = None,
+                 nbits_expiration: float = 60):
+        assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        if initial_group_bits is None:
+            search_result = dht.get(f"{prefix}.0b", latest=True)
+            initial_group_bits = self.get_suggested_nbits(search_result) or ''
+        self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
+        self.target_group_size = target_group_size
+        self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
+        self.excessive_size = excessive_size or target_group_size * 3
+        self.nbits_expiration = nbits_expiration
+        self.suggested_nbits: Optional[int] = None
+
+    @property
+    def current_key(self) -> GroupKey:
+        return f"{self.prefix}.0b{self.group_bits}"
+
+    async def declare_averager(self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float,
+                               looking_for_group: bool = True) -> bool:
+        """
+        Add (or remove) the averager to a given allreduce bucket
+
+        :param group_key: allreduce group key, e.g. my_averager.0b011011101
+        :param endpoint: averager public endpoint for incoming requests
+        :param expiration_time: intent to run allreduce before this timestamp
+        :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
+          If False, this will instead mark the averager as no longer looking for a group (e.g. it has already finished)
+        :return: True if declared, False if declaration was rejected by DHT peers
+        :note: when leaving (i.e. is_active=False), please specify the same expiration_time as when entering the group
+        :note: setting is_active=False does *not* guarantee that other peers will immediately stop querying you.
+        """
+        expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float('inf')))
+        return await self.dht.store(key=group_key, subkey=endpoint, value=looking_for_group,
+                                    expiration_time=expiration_time, return_future=True)
+
+    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
+        """
+        Find and return averagers that were declared with a given all-reduce key
+
+        :param group_key: finds averagers that have this group key, e.g. my_averager.0b011011101
+        :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
+            if False, return all averagers under a given group_key regardless of value
+        :return: endpoints and expirations of every matching averager
+        """
+        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
+        result = await self.dht.get(group_key, latest=True, return_future=True)
+        if result is None or not isinstance(result.value, dict):
+            logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
+            return []
+        averagers = [(key, entry.expiration_time) for key, entry in result.value.items()
+                     if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)]
+        num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
+
+        suggested_nbits = self.get_suggested_nbits(result)
+        if suggested_nbits is not None and suggested_nbits != self.suggested_nbits:
+            self.suggested_nbits = suggested_nbits
+            logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
+        elif num_active_averagers >= self.excessive_size:
+            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
+            logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+        return averagers
+
+    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
+        """ notify other peers that they can run averaging at this depth """
+        return await self.dht.store(key=group_key, subkey=self.RESERVED_KEY_FOR_NBITS, value=nbits,
+                                    expiration_time=expiration_time, return_future=True)
+
+    @classmethod
+    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
+        if isinstance(search_result, ValueWithExpiration) and cls.RESERVED_KEY_FOR_NBITS in search_result.value \
+                and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int):
+            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
+        else:
+            return None
+
+    async def update_key_on_group_assembled(self, allreduce_group: AllReduceRunner, is_leader: bool = True):
+        """ this function is triggered every time an averager finds an allreduce group """
+        rng = random.Random(allreduce_group.group_key_seed)
+        index = allreduce_group.ordered_group_endpoints.index(self.endpoint)
+        generalized_index = rng.sample(range(self.target_group_size), allreduce_group.group_size)[index]
+        nbits = int(np.ceil(np.log2(self.target_group_size)))
+        new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
+        self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):]
+        logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
+
+        if is_leader and self.insufficient_size < allreduce_group.group_size < self.excessive_size:
+            asyncio.create_task(self.notify_stragglers_on_success())
+        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
+            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
+            self.group_bits = ''.join((random.choice('01') for _ in range(num_extra_bits))) + self.group_bits
+            self.group_bits = self.group_bits[-self.suggested_nbits:]
+        self.suggested_nbits = None
+
+    async def update_key_on_not_enough_peers(self):
+        """ this function is triggered whenever averager fails to assemble group within timeout """
+        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
+        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:]
+        if self.group_bits != prev_nbits:
+            logger.warning(f'{self.endpoint} - switching to {len(self.group_bits)}-bit keys')
+        self.suggested_nbits = None
+
+    async def notify_stragglers_on_success(self):
+        """ Find averagers that have fewer nbits and redirect them to your current nbits """
+        for nbits in reversed(range(1, len(self.group_bits) - 1)):
+            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
+            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
+
+            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
+                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
+                break
+
+        root_data = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True)
+        if root_data is None or self.RESERVED_KEY_FOR_NBITS not in root_data.value:
+            await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)

+ 61 - 54
hivemind/client/averaging/matchmaking.py

@@ -12,14 +12,15 @@ import asyncio
 import grpc
 import torch
 
-import hivemind
 from hivemind.client.averaging.allreduce import AllReduceRunner
 from hivemind.client.averaging.load_balancing import load_balance_peers
-from hivemind.dht import DHTID, DHTExpiration, get_dht_time, GroupKey
-from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, TimedStorage
+from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
+from hivemind.dht import DHT, DHTID, DHTExpiration, get_dht_time
+from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, timed_storage, TimedStorage
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.utils.grpc import ChannelCache
 
+
 logger = get_logger(__name__)
 
 
@@ -36,7 +37,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
       Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
     """
 
-    def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
+    def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *,
                  prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
                  averaging_expiration: float = 15, request_timeout: float, throughput: Optional[float] = None,
                  min_vector_size: int, **allreduce_kwargs):
@@ -46,12 +47,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                            "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring.")
 
         super().__init__()
-        self.dht, self.endpoint, self.averaged_tensors = dht, endpoint, tuple(averaged_tensors)
-        self.prefix, self.group_bits = prefix, initial_group_bits
+        self.endpoint, self.averaged_tensors = endpoint, tuple(averaged_tensors)
+        self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
-        self.throughput, self.min_vector_size = throughput, min_vector_size
-        self.allreduce_kwargs = allreduce_kwargs
+        self.throughput, self.min_vector_size, self.allreduce_kwargs = throughput, min_vector_size, allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
         self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)
 
@@ -63,17 +63,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration, target_group_size)
-        self.data_for_gather: bytes = None
+        self.potential_leaders = PotentialLeaders(endpoint, averaging_expiration, target_group_size)
+        self.data_for_gather: Optional[bytes] = None
 
     @property
     def is_looking_for_group(self):
         return self.lock_looking_for_group.locked()
 
-    @property
-    def current_group_key(self) -> GroupKey:
-        return f"{self.prefix}.0b{self.group_bits}"
-
     def __repr__(self):
         lfg_status = "looking for group," if self.is_looking_for_group else "not looking for group,"
         if self.is_looking_for_group:
@@ -83,12 +79,12 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
-               f" current key = {self.current_group_key})"
+               f" current key = {self.group_key_manager.current_key})"
 
     async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
                              ) -> Optional[AllReduceRunner]:
         """
-        :param gather: optionally send this data to all peers in the next group and gather it from every groupmate
+        :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates
         :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
@@ -127,10 +123,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
-        async with self.potential_leaders.begin_search(self.current_group_key, timeout):
-            # TODO update group_bits on success! reduce number of bits on not enough peers.
-            # TODO after allreduce finishes, we may need to ask leader to notify lower keys about this
-            # (so as to fix possible network partitioning if some peers operate on a much smaller nbits)
+        async with self.potential_leaders.begin_search(self.group_key_manager, timeout):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
@@ -148,7 +141,6 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                             return await self.leader_assemble_group()
                         elif len(self.current_followers) > 0:
                             await self.leader_disband_group()
-                            # TODO maybe adjust grid size
                         continue
                 except Exception as e:
                     if not self.assembled_group.done():
@@ -262,8 +254,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             allreduce_group = self.assembled_group.result()
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
-                ordered_group_endpoints=allreduce_group.ordered_group_endpoints,
-                part_sizes=allreduce_group.part_sizes, gathered=allreduce_group.gathered)
+                ordered_group_endpoints=allreduce_group.ordered_group_endpoints, part_sizes=allreduce_group.part_sizes,
+                gathered=allreduce_group.gathered, group_key_seed=allreduce_group.group_key_seed)
 
         except Exception as e:
             logger.exception(e)
@@ -319,11 +311,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 gathered.append(follower_info.gather if follower_info.gather else None)
 
         part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
+        group_key_seed = random.randint(- 2 ** 31, 2 ** 31 - 1)
 
         logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
                                           ordered_group_endpoints=ordered_group_endpoints, part_sizes=part_sizes,
-                                          gathered=gathered, **self.allreduce_kwargs)
+                                          gathered=gathered, group_key_seed=group_key_seed, **self.allreduce_kwargs)
+        await self.group_key_manager.update_key_on_group_assembled(allreduce_group, is_leader=True)
         self.assembled_group.set_result(allreduce_group)
         return allreduce_group
 
@@ -340,7 +334,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
                                           ordered_group_endpoints=tuple(ordered_group_endpoints),
-                                          part_sizes=tuple(part_sizes), gathered=msg.gathered, **self.allreduce_kwargs)
+                                          part_sizes=tuple(part_sizes), gathered=msg.gathered,
+                                          group_key_seed=int(msg.group_key_seed), **self.allreduce_kwargs)
+        await self.group_key_manager.update_key_on_group_assembled(allreduce_group)
         self.assembled_group.set_result(allreduce_group)
         return allreduce_group
 
@@ -353,9 +349,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 class PotentialLeaders:
     """ An utility class that searches for averagers that could become our leaders """
 
-    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration,
-                 target_group_size: Optional[int]):
-        self.endpoint, self.dht, self.averaging_expiration = endpoint, dht, averaging_expiration
+    def __init__(self, endpoint: Endpoint, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
+        self.endpoint, self.averaging_expiration = endpoint, averaging_expiration
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
@@ -367,12 +362,12 @@ class PotentialLeaders:
         self.search_end_time = float('inf')
 
     @contextlib.asynccontextmanager
-    async def begin_search(self, group_key: GroupKey, timeout: Optional[float]):
+    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float]):
         async with self.lock_search:
             self.running.set()
             self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
-            update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
-            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
+            update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
+            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
             try:
                 yield self
             finally:
@@ -429,38 +424,46 @@ class PotentialLeaders:
         else:
             return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
 
-    async def _update_queue_periodically(self, group_key: GroupKey):
-        DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-        while get_dht_time() < self.search_end_time:
-            new_peers = await self.dht.get_averagers(group_key, only_active=True, return_future=True)
-            self.max_assured_time = max(self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY)
-
-            self.leader_queue.clear()
-            for peer, peer_expiration_time in new_peers:
-                if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
-                    continue
-                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
+        try:
+            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+            while get_dht_time() < self.search_end_time:
+                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+                self.max_assured_time = max(self.max_assured_time,
+                                            get_dht_time() + self.averaging_expiration - DISCREPANCY)
+
+                self.leader_queue.clear()
+                for peer, peer_expiration_time in new_peers:
+                    if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
+                        continue
+                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
 
-            self.update_finished.set()
+                self.update_finished.set()
 
-            await asyncio.wait(
-                {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
-                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
-            self.update_triggered.clear()
+                await asyncio.wait(
+                    {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
+                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
+                self.update_triggered.clear()
+        except Exception as e:
+            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            raise
 
-    async def _declare_averager_periodically(self, group_key: GroupKey):
+    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
         async with self.lock_declare:
             try:
                 while True:
                     await self.running.wait()
 
                     new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-                    self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
+                    self.declared_group_key = group_key = key_manager.current_key
+                    self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
-                    await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
-                                                    looking_for_group=True, return_future=True)
+                    await key_manager.declare_averager(group_key, self.endpoint, expiration_time=new_expiration_time)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
+                    if self.running.is_set() and len(self.leader_queue) == 0:
+                        await key_manager.update_key_on_not_enough_peers()
+
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
                 logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
             finally:
@@ -468,8 +471,8 @@ class PotentialLeaders:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                     self.declared_group_key, self.declared_expiration_time = None, float('inf')
                     self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
-                    await self.dht.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
-                                                    looking_for_group=False, return_future=True)
+                    await key_manager.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
+                                                       looking_for_group=False)
 
 
 def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
@@ -478,3 +481,7 @@ def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
                      for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                     for tensor in tensors]
     return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
+
+
+class MatchmakingException(Exception):
+    """ An internal exception that marks undesired edge cases during averaging """

+ 51 - 77
hivemind/dht/__init__.py

@@ -12,6 +12,7 @@ The code is organized as follows:
 - [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric.
 - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
 """
+from __future__ import annotations
 import asyncio
 import ctypes
 import heapq
@@ -21,12 +22,11 @@ from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
 
-from numpy import nextafter
 
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
-from hivemind.dht.routing import get_dht_time, DHTValue
-from hivemind.utils import MPFuture, Endpoint, Hostname, get_logger, switch_to_uvloop, strip_port
+from hivemind.dht.routing import get_dht_time, DHTValue, DHTKey, Subkey
+from hivemind.utils import MPFuture, Endpoint, Hostname, get_logger, switch_to_uvloop, strip_port, ValueWithExpiration
 
 logger = get_logger(__name__)
 
@@ -37,8 +37,6 @@ FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to spe
 UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
 PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
 #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
-GroupKey = str
-GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]+$')  # e.g. bert_exp4_averaging.0b01001101
 
 
 def is_valid_uid(maybe_uid: str) -> bool:
@@ -50,12 +48,6 @@ def is_valid_prefix(maybe_prefix: str) -> bool:
     """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
     return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
 
-
-def is_valid_group(maybe_group: str) -> bool:
-    """ A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata"""
-    return bool(GROUP_PATTERN.fullmatch(maybe_group))
-
-
 def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
     """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
@@ -180,6 +172,54 @@ class DHT(mp.Process):
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None
 
+    def get(self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
+            ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
+        """
+        Search for a key across DHT and return either first or latest entry (if found).
+        :param key: same key as in node.store(...)
+        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
+        :returns: (value, expiration time); if value was not found, returns None
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
+        try:
+            result = await node.get(key, latest=latest, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+              subkey: Optional[Subkey] = None, return_future: bool = False, **kwargs) -> Union[bool, MPFuture]:
+        """
+        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
+        :note: store is a simplified interface to store_many; all kwargs are forwarded there
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: True if store succeeds, False if it fails (due to no response or newer value)
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
+                                           future=_future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+                     subkey: Optional[Subkey], future: MPFuture, **kwargs):
+        try:
+            result = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
     def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
         """
         Get this machine's visible address by requesting other peers or using pre-specified network addresses.
@@ -519,69 +559,3 @@ class DHT(mp.Process):
         if future is not None:
             future.set_result(best_experts_batch)
         return best_experts_batch
-
-    def declare_averager(self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float, *,
-                         looking_for_group: bool = True, return_future: bool = False) -> Union[bool, MPFuture]:
-        """
-        Add (or remove) the averager to a given allreduce bucket
-
-        :param group_key: allreduce group key, e.g. my_averager.0b011011101
-        :param endpoint: averager public endpoint for incoming requests
-        :param expiration_time: intent to run allreduce before this timestamp
-        :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
-          If False, this will instead mark that the averager as no longer looking for group, (e.g. it already finished)
-        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :return: True if declared, False if declaration was rejected by DHT peers
-        :note: when leaving (i.e. is_active=False), please specify the same expiration_time as when entering the group
-        :note: setting is_active=False does *not* guarantee that others will immediately stop to query you.
-        """
-        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_declare_averager', [],
-                        dict(group_key=group_key, endpoint=endpoint, expiration_time=expiration_time,
-                             looking_for_group=looking_for_group, future=_future)))
-        return future if return_future else future.result()
-
-    async def _declare_averager(self, node: DHTNode, *, group_key: str, endpoint: Endpoint,
-                                expiration_time: DHTExpiration, looking_for_group: bool, future: MPFuture):
-        try:
-            expiration_time = expiration_time if looking_for_group else float(nextafter(expiration_time, float('inf')))
-            # ^-- when declaring averager inactive, we increment expiration time to overwrite the pre-existing entry
-            store_ok = await node.store(
-                key=group_key, subkey=endpoint, value=looking_for_group, expiration_time=expiration_time)
-            future.set_result(store_ok)
-        except Exception as e:
-            if not future.done():
-                future.set_exception(e)
-
-    def get_averagers(self, group_key: GroupKey, *, only_active: bool = True, return_future: bool = False
-                      ) -> Union[List[Tuple[Endpoint, DHTExpiration]], MPFuture]:
-        """
-        Find and return averagers in a specified all-reduce bucket
-
-        :param group_key: finds averagers that have the this group key, e.g. my_averager.0b011011101
-        :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
-            if False, return all averagers under a given group_key regardless of value
-        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :return: endpoints and expirations of every matching averager
-        """
-        assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_averagers', [], dict(group_key=group_key, only_active=only_active, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_averagers(self, node: DHTNode, *, group_key: str, only_active: bool, future: MPFuture):
-        try:
-            result = await node.get(group_key, latest=True)
-            if result is None:
-                logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
-                future.set_result([])
-                return
-            assert isinstance(result.value, dict), f"expected {group_key} to be a Dict[Endpoint, is_active], " \
-                                                   f"but got {result.value} of type {type(result.value)}."
-            averagers = [(endpoint, entry.expiration_time) for endpoint, entry in result.value.items()
-                         if not only_active or entry.value is True]
-            future.set_result(averagers)
-        except Exception as e:
-            if not future.done():
-                future.set_exception(e)
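
With declare_averager/get_averagers removed from DHT, GroupKeyManager builds on the new generic get/store wrappers shown above. A minimal usage sketch with one subkey per peer under a shared group key; the key and endpoint names are illustrative:

import hivemind

dht = hivemind.DHT(start=True)
expiration_time = hivemind.get_dht_time() + 60

# store one subkey per peer under a shared key, the way GroupKeyManager declares averagers
dht.store(key='demo_group.0b01', subkey='peer:1337', value=True, expiration_time=expiration_time)

# latest=True fetches the freshest copy; the result is a ValueWithExpiration (or None)
found = dht.get('demo_group.0b01', latest=True)
if found is not None:
    for endpoint, entry in found.value.items():
        print(endpoint, entry.value, entry.expiration_time)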

+ 1 - 0
hivemind/proto/averaging.proto

@@ -44,6 +44,7 @@ message MessageFromLeader {
   repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
   repeated int32 part_sizes = 5;  // a sequence of tensor parts assigned to each peer, same order as endpoints
   repeated bytes gathered = 6;  // metadata (gather) from all groupmates in the same order as their endoints
+  int32 group_key_seed = 7;  // a random seed used by peers to update their group keys
 }
 
 message AveragingData {

+ 5 - 1
tests/benchmark_averaging.py

@@ -1,3 +1,4 @@
+import math
 import time
 import threading
 import argparse
@@ -31,6 +32,8 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
                         averaging_expiration: float, request_timeout: float, round_timeout: float,
                         hid_size: int, num_layers: int, spawn_dtime: float):
     dht_root = hivemind.DHT(listen_on=f'{LOCALHOST}:*', start=True)
+    num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))
+    nbits = int(round(math.log2(num_groups)))
     peer_tensors = [sample_tensors(hid_size, num_layers)
                     for _ in range(num_peers)]
     processes = {dht_root}
@@ -39,8 +42,9 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
         dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
                            initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
                            start=True)
+        initial_bits = bin(index % num_groups)[2:].rjust(nbits, '0')
         averager = hivemind.DecentralizedAverager(
-            peer_tensors[i], dht, prefix='my_tensor', initial_group_bits='0110', listen_on=f"{LOCALHOST}:*",
+            peer_tensors[i], dht, prefix='my_tensor', initial_group_bits=initial_bits, listen_on=f"{LOCALHOST}:*",
             compression_type=runtime_pb2.CompressionType.FLOAT16, target_group_size=target_group_size,
             averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
         processes.update({dht, averager})
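
A worked instance of the new initial-bits assignment above, assuming num_peers=16 and target_group_size=4 (illustrative values, not the benchmark defaults):

import math

num_peers, target_group_size = 16, 4
num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))   # 4 buckets
nbits = int(round(math.log2(num_groups)))                                # 2-bit keys
for index in range(6):
    print(index, bin(index % num_groups)[2:].rjust(nbits, '0'))
# peers 0..3 start in buckets '00', '01', '10', '11'; peer 4 wraps back to '00', and so on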

+ 83 - 12
tests/test_averaging.py

@@ -7,31 +7,38 @@ import pytest
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
 from hivemind.client.averaging.load_balancing import load_balance_peers
+from hivemind.client.averaging.key_manager import GroupKeyManager
 from hivemind.utils import Endpoint
 
 
 @pytest.mark.forked
-def test_getset_averagers():
-    dht = hivemind.DHT(start=True)
+@pytest.mark.asyncio
+async def test_key_manager():
+    key_manager = GroupKeyManager(hivemind.DHT(start=True), endpoint='localhvost',
+                                  prefix='test_averaging', initial_group_bits='10110',
+                                  target_group_size=2)
 
     t = hivemind.get_dht_time()
-    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 60)
-    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', expiration_time=t + 61)
+    key = key_manager.current_key
+    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 60)
+    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61)
+
+    q1 = await key_manager.get_averagers(key, only_active=True)
 
-    q1 = dht.get_averagers('bucket.0b10110', only_active=True)
+    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 66)
+    q2 = await key_manager.get_averagers(key, only_active=True)
 
-    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 66)
-    q2 = dht.get_averagers('bucket.0b10110', only_active=True)
+    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61, looking_for_group=False)
+    q3 = await key_manager.get_averagers(key, only_active=True)
+    q4 = await key_manager.get_averagers(key, only_active=False)
 
-    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', looking_for_group=False,
-                         expiration_time=t + 61)
-    q3 = dht.get_averagers('bucket.0b10110', only_active=True)
-    q4 = dht.get_averagers('bucket.0b10110', only_active=False)
+    q5 = await key_manager.get_averagers('nonexistent_key.0b0101', only_active=False)
 
     assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
     assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
     assert len(q3) == 1 and ('localhvost', t + 66) in q3
     assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q2
+    assert len(q5) == 0
 
 
 @pytest.mark.forked
@@ -46,7 +53,7 @@ def test_allreduce_once():
     reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
 
     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', initial_group_bits='0110', listen_on='127.0.0.1:*',
+                                                prefix='mygroup', listen_on='127.0.0.1:*',
                                                 start=True)
                  for tensors in [tensors1, tensors2, tensors3, tensors4]]
 
@@ -64,6 +71,44 @@ def test_allreduce_once():
                 assert torch.allclose(ref, our, atol=1e-6)
 
 
+def compute_mean_std(averagers, unbiased=True):
+    results = []
+    for averager in averagers:
+        with averager.get_tensors() as tensors:
+            results.append([tensor.clone() for tensor in tensors])
+
+    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
+    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
+    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
+    return means, stds
+
+
+@pytest.mark.forked
+def test_allreduce_grid():
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    averagers = [hivemind.DecentralizedAverager(
+        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
+        prefix='mygroup', initial_group_bits=bin(i // 2)[2:].rjust(2, '0'), start=True)
+        for i in range(8)]
+
+    [means0], [stds0] = compute_mean_std(averagers)
+    assert not torch.allclose(stds0, torch.zeros_like(stds0))
+
+    prev_means, prev_stds = means0, stds0
+
+    for i in range(5):
+        step_futures = [averager.step(wait=False) for averager in averagers]
+        groups = [future.result() for future in step_futures]
+        [means], [stds] = compute_mean_std(averagers)
+        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
+        assert all(len(group) == 2 for group in groups)
+
+        if i <= 2:
+            assert torch.all(torch.le(stds, prev_stds))
+        else:
+            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
+
+
 @pytest.mark.forked
 def test_allgather():
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
@@ -190,3 +235,29 @@ def test_load_balancing():
         assignment = load_balance_peers(vector_size, throughputs, min_size)
         assert np.sum(assignment) == vector_size
         assert np.min(assignment) >= 0
+
+
+@pytest.mark.forked
+def test_too_few_peers():
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    averagers = [hivemind.DecentralizedAverager(
+        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
+        averaging_expiration=1, request_timeout=0.5,
+        prefix='mygroup', initial_group_bits=bin(i)[2:].rjust(3, '0'), start=True)
+        for i in range(4)]
+    step_futures = [averager.step(wait=False) for averager in averagers]
+    for future in step_futures:
+        assert len(future.result()) == 2
+
+
+@pytest.mark.forked
+def test_overcrowded():
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    averagers = [hivemind.DecentralizedAverager(
+        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
+        averaging_expiration=1, request_timeout=0.5,
+        prefix='mygroup', initial_group_bits='', start=True)
+        for _ in range(32)]
+    for t in range(5):
+        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
+        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1