Address averaging corner cases, add benchmark_averaging.py, chunk averaged tensors, fix DHTNode get (#134)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
e159605143

+ 0 - 6
.circleci/config.yml

@@ -21,12 +21,6 @@ jobs:
       - run:
           command: pytest ./tests
           name: tests
-      - run:
-          command: python tests/benchmark_throughput.py --preset minimalistic
-          name: benchmark_throughput
-      - run:
-          command: python tests/benchmark_dht.py
-          name: benchmark_dht
       - run:
           command: codecov
           name: codecov
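
Note: this commit drops benchmark_throughput and benchmark_dht from the CI test job; the new tests/benchmark_averaging.py named in the commit title is not shown in this diff. Below is a minimal sketch of what such an averaging benchmark could look like with the API introduced here — the peer count, tensor shapes, and endpoints are illustrative assumptions, and it presumes DecentralizedAverager and DHT are exported at package level.

import time
import torch
import hivemind

NUM_PEERS, TARGET_GROUP_SIZE = 4, 4  # assumed benchmark parameters, not from the actual file

# one DHT node to bootstrap from; every peer runs its own DHT instance and averager
dht_root = hivemind.DHT(listen_on='127.0.0.1:*', start=True)
initial_peers = [f"127.0.0.1:{dht_root.port}"]

averagers = []
for _ in range(NUM_PEERS):
    dht = hivemind.DHT(listen_on='127.0.0.1:*', initial_peers=initial_peers, start=True)
    averagers.append(hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(16, 1024)], dht=dht, start=True,
        prefix='benchmark', target_group_size=TARGET_GROUP_SIZE))

start = time.perf_counter()
futures = [averager.step(wait=False) for averager in averagers]  # wait=False returns an MPFuture
successes = [future.result() for future in futures]
print(f"{sum(successes)}/{NUM_PEERS} peers averaged in {time.perf_counter() - start:.3f}s")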

+ 114 - 55
hivemind/client/averaging/__init__.py

@@ -2,22 +2,23 @@
 
 from __future__ import annotations
 
-import random
+import asyncio
+import contextlib
 import ctypes
-from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
-from concurrent.futures.thread import ThreadPoolExecutor
 import multiprocessing as mp
-import asyncio
+import random
+from concurrent.futures.thread import ThreadPoolExecutor
+from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
-import torch
-import uvloop
 import grpc
+import torch
 
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
 from hivemind.client.averaging.matchmaking import Matchmaking
-from hivemind.utils import get_logger, Endpoint, Port, MPFuture, replace_port, GRPC_KEEPALIVE_OPTIONS
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+from hivemind.utils import get_logger, Endpoint, Port, MPFuture, replace_port, GRPC_KEEPALIVE_OPTIONS, get_dht_time
+from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
@@ -46,7 +47,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
     :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
     :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
-
+    :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
+    :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
+    :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
             if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
@@ -62,13 +65,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     _pending_group_assembled: asyncio.Event
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
-                 prefix: str, target_group_size: int, min_group_size: int = 1, initial_group_bits: Optional[str] = None,
+                 prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
                  averaging_expiration: float = 15, allreduce_timeout: Optional[float] = None,
+                 request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-                 listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1,
+                 listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
-        assert '.' not in prefix, "group prefix must be a string without ."
-        if is_power_of_two(target_group_size):
+        assert '.' not in prefix, "group prefix must be a string without trailing '.'"
+        if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         if initial_group_bits is None:
             initial_group_bits = ''.join(random.choices('01', k=INITIAL_GROUP_NBITS))
@@ -79,16 +83,19 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.dht = dht
         self.listen_on, self.receiver_threads, self.kwargs = listen_on, receiver_threads, kwargs
         self.channel_options = channel_options
-        self.averaged_tensors = tuple(averaged_tensors)
-        # TODO use mp.Lock to prevent someone from modifying tensors before we copy them! maybe.
-        for tensor in self.averaged_tensors:
+        self.daemon = daemon
+
+        self._averaged_tensors = tuple(averaged_tensors)
+        self.lock_averaged_tensors = mp.Lock()
+        for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
 
-        self.matchmaking_kwargs = dict(prefix=prefix, initial_group_bits=initial_group_bits,
-                                       target_group_size=target_group_size, min_group_size=min_group_size,
-                                       averaging_expiration=averaging_expiration)
-        self.allreduce_timeout, self.compression_type = allreduce_timeout, compression_type
+        self.matchmaking_kwargs = dict(
+            prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
+            min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout,
+            chunk_size_bytes=chunk_size_bytes, compression_type=compression_type)
+        self.allreduce_timeout = allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
@@ -115,13 +122,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     def run(self):
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
-        if asyncio.get_event_loop().is_running():
-            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-
-        uvloop.install()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
+        loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
         pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
 
@@ -132,7 +133,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             found_port = server.add_insecure_port(self.listen_on)
             assert found_port != 0, f"Failed to listen to {self.listen_on}"
             self._port.value = found_port
-            self._matchmaking = Matchmaking(self.endpoint, self.averaged_tensors, self.dht, **self.matchmaking_kwargs)
+            self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
+                                            return_deltas=True)  # note: we need deltas to make allreduce lock-free
             self._pending_group_assembled = asyncio.Event()
             self._pending_group_assembled.set()
             await server.start()
@@ -161,37 +163,88 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         else:
             logger.warning("DHT shutdown has no effect: the process is not alive")
 
-    def step(self, timeout: Optional[float] = None, return_future=False) -> Union[Sequence[torch.Tensor], MPFuture]:
+    def step(self, allow_retries: bool = True, timeout: Optional[float] = None, wait=True
+             ) -> Union[bool, MPFuture]:
         """
-        Set up the averager to look for a group and run one round of averaging, then return the averaged tensors
-
+        Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
+        :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
+          within the specified timeout
         :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_step', [], dict(future=_future, timeout=timeout)))
-        return future if return_future else future.result()
+        self.pipe.send(('_step', [], dict(future=_future, allow_retries=allow_retries, timeout=timeout)))
+        return future.result() if wait else future
+
+    async def _step(self, *, future: MPFuture, allow_retries: bool, timeout: Optional[float]):
+        loop = asyncio.get_event_loop()
+        start_time = get_dht_time()
 
-    async def _step(self, *, future: MPFuture, timeout: Optional[float]):
+        try_averaging = True
         group_id = None
-        try:
-            self._pending_group_assembled.clear()
-            allreduce_group = await self._matchmaking.look_for_group(timeout=timeout)
-            group_id = allreduce_group.group_id
-            if allreduce_group is not None:
+
+        while try_averaging:
+            try:
+                self._pending_group_assembled.clear()
+                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout)
+                if allreduce_group is None:
+                    raise AllreduceException("Averaging step failed: could not find a group.")
+
+                group_id = allreduce_group.group_id
                 self._running_groups[group_id] = allreduce_group
                 self._pending_group_assembled.set()
-                future.set_result(await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout))
-            else:
-                raise AllreduceException(f"{self} - group_allreduce failed, unable to find a group")
-
-        except Exception as e:
-            future.set_exception(e)
-            raise
-        finally:
-            self._pending_group_assembled.set()
-            if group_id is not None:
+                averaging_deltas = await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout)
+                update_ok = await loop.run_in_executor(None, lambda: self.update_tensors(averaging_deltas, add=True))
+
+                # averaging is finished, exit the loop
+                future.set_result(update_ok)
+                try_averaging = False
+
+            except AllreduceException:
+                time_elapsed = get_dht_time() - start_time
+                if not allow_retries or (timeout is not None and timeout < time_elapsed):
+                    future.set_result(False)
+                    try_averaging = False
+
+            except Exception as e:
+                future.set_exception(e)
+                raise
+            finally:
                 _ = self._running_groups.pop(group_id, None)
+                self._pending_group_assembled.set()
+
+    def update_tensors(self, tensors: Sequence[torch.Tensor], *, add: bool = False) -> bool:
+        """
+        Set or change the values of self.averaged_tensors.
+
+        :param tensors: list/tuple of tensors of same shape as self.averaged_tensors
+        :param add: if True, add tensors to self.averaged_tensors in-place
+          by default, simply write the values of :tensors: to self.averaged_tensors
+        :note: if there may be updates running in background, it is recommended to use add=True
+        """
+        assert len(tensors) == len(self._averaged_tensors)
+        with torch.no_grad(), self.lock_averaged_tensors:
+            for tensor, update in zip(self._averaged_tensors, tensors):
+                if add:
+                    tensor += update
+                else:
+                    tensor[...] = update
+        return True
+
+    @contextlib.contextmanager
+    def get_tensors(self) -> Sequence[torch.Tensor]:
+        """
+        A contextmanager that gives user access to averaged tensors.
+        It is guaranteed that the averager will not modify tensors while this context is active.
+
+        Example:
+              >>> with averager.get_tensors() as tensors:
+              >>>     update_model(tensors)
+              >>>     tensors[0] += 1
+              >>> # do not use tensors after the lock is released
+        """
+        with self.lock_averaged_tensors:
+            yield self._averaged_tensors
 
     async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                              ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -199,16 +252,22 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         async for response in self._matchmaking.rpc_join_group(request, context):
             yield response
 
-    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
+    async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+                                 ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
-        if request.group_id not in self._running_groups and not self._pending_group_assembled.is_set():
+        request = await anext(stream)
+        if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
             # but his response with group_id was delayed and other peers got to us first
             await self._pending_group_assembled.wait()
-        if request.group_id not in self._running_groups:
-            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
-        else:
-            return await self._running_groups[request.group_id].rpc_aggregate_part(request, context)
+
+        group = self._running_groups.get(request.group_id)
+        if group is None:
+            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+            return
+
+        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
+            yield message
 
 
 def is_power_of_two(n):
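
The net effect of the changes above: step() no longer returns the averaged tensors; it runs allreduce on deltas, applies them under lock_averaged_tensors, and returns a success flag, while update_tensors() and the get_tensors() context manager become the supported ways to write and read the shared buffers. A usage sketch under assumed setup follows (the model, bootstrap endpoint, and group size are placeholders, not part of this commit):

import torch
import hivemind

model = torch.nn.Linear(512, 10)
dht = hivemind.DHT(initial_peers=['127.0.0.1:31337'], start=True)  # placeholder bootstrap peer
averager = hivemind.DecentralizedAverager(
    averaged_tensors=[p.detach().clone() for p in model.parameters()],  # must be leaf tensors
    dht=dht, start=True, prefix='demo_run', target_group_size=4)

# write the latest local parameters into the averager's shared buffers
averager.update_tensors([p.data for p in model.parameters()])

if averager.step(timeout=60):  # True iff one round of allreduce succeeded and deltas were applied
    with averager.get_tensors() as averaged_tensors:  # tensors won't change inside this context
        for param, averaged in zip(model.parameters(), averaged_tensors):
            param.data.copy_(averaged)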

+ 89 - 40
hivemind/client/averaging/allreduce.py

@@ -1,10 +1,11 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple
+from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Iterator
 
 import grpc
 import torch
 
-from hivemind.utils import Endpoint, get_logger, serialize_torch_tensor, deserialize_torch_tensor, ChannelCache
+from hivemind.utils import Endpoint, get_logger, ChannelCache, anext
+from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
 from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
 
 # flavour types
@@ -19,25 +20,32 @@ class AllReduceProtocol:
     :param tensors: local tensors that should be averaged with groupmates
     :param endpoint: your endpoint, must be included in ordered_group_endpoints
     :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    :param return_deltas: if True, returns the element-wise differences (averaged_tensors - original_tensors)
+           default (False) - return averaged_tensors by themselves
     """
+
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint]):
+                 ordered_group_endpoints: Sequence[Endpoint], return_deltas: bool = False):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
         self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, self.group_size)))
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
+        self.return_deltas = return_deltas
 
         self.accumulator = self.local_tensor_parts[self.endpoint].clone()  # sum inputs from peers to this tensor
         self.accumulated_from: Set[Endpoint] = {self.endpoint}  # peers that we have accumulated our part from
         self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
         self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
-        self.averaged_tensors: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
+        self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
 
     def __await__(self):
-        return self.averaged_tensors.__await__()
+        return self.future.__await__()
+
+    def __contains__(self, endpoint: Endpoint):
+        return endpoint in self.local_tensor_parts
 
     @property
     def group_size(self):
@@ -46,7 +54,7 @@ class AllReduceProtocol:
     async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor) -> torch.Tensor:
         """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
         assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
-        assert not self.averaged_tensors.done(), f"already finished allreduce: {self.averaged_tensors}"
+        assert not self.future.done(), f"already finished allreduce: {self.future}"
         assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
         assert source not in self.accumulated_from, "duplicate source, already received that part"
         logger.debug(f"{self} - accumulating tensor part from {source}")
@@ -63,7 +71,7 @@ class AllReduceProtocol:
         return await self.averaged_part
 
     def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
-        assert not self.averaged_tensors.done(), f"already finished allreduce: {self.averaged_tensors}"
+        assert not self.future.done(), f"already finished allreduce: {self.future}"
         assert source in self.local_tensor_parts, "the provider of averaged part is not from my group"
         assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
         assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
@@ -72,28 +80,37 @@ class AllReduceProtocol:
         self.averaged_tensor_parts[source] = averaged_part
         if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
             ordered_averaged_parts = [self.averaged_tensor_parts[endpoint] for endpoint in self.ordered_group_endpoints]
-            self.averaged_tensors.set_result(restore_from_parts(ordered_averaged_parts, self.tensor_shapes))
+            outputs = restore_from_parts(ordered_averaged_parts, self.tensor_shapes)
+
+            if self.return_deltas:
+                local_parts = [self.local_tensor_parts[peer] for peer in self.ordered_group_endpoints]
+                with torch.no_grad():
+                    original_tensors = restore_from_parts(local_parts, self.tensor_shapes)
+                    for averaged_tensor, original_tensor in zip(outputs, original_tensors):
+                        averaged_tensor -= original_tensor
+
+            self.future.set_result(outputs)
 
     def cancel(self) -> bool:
-        if not self.averaged_tensors.done():
+        if not self.future.done():
             logger.debug(f"{self} - cancelled")
-            self.averaged_tensors.cancel()
+            self.future.cancel()
             if not self.averaged_part.done():
                 self.averaged_part.cancel()
             return True
         else:
-            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.averaged_tensors}")
+            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.future}")
             return False
 
     def set_exception(self, exception: Exception) -> bool:
-        if not self.averaged_tensors.done():
+        if not self.future.done():
             logger.debug(f"{self} - {exception}")
-            self.averaged_tensors.set_exception(exception)
+            self.future.set_exception(exception)
             if not self.averaged_part.done():
                 self.averaged_part.cancel()
             return True
         else:
-            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.averaged_tensors}")
+            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.future}")
             return False
 
 
@@ -101,11 +118,14 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
     """
     """
     A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
     A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
     """
     """
+
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType):
+                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
+                 chunk_size_bytes: int, return_deltas: bool = False):
         super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint,
         super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint,
-                         ordered_group_endpoints=ordered_group_endpoints)
-        self.compression_type = compression_type
+                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
+        self.compression_type, self.chunk_size_bytes = compression_type, chunk_size_bytes
+        self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future()
 
 
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
@@ -113,55 +133,84 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
     async def _average_one_part(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
         """ Send one part of local tensors to one groupmate and collect the average for this part """
         serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-        response = await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(
-            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
-                                        endpoint=self.endpoint, tensor_part=serialized_tensor_part))
-        if response.code == averaging_pb2.AVERAGED_PART:
-            averaged_part = deserialize_torch_tensor(response.tensor_part)
-            self.register_averaged_part(peer_endpoint, averaged_part)
-            return averaged_part
-        else:
-            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(response.code)}"
+        chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)
+
+        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
+                                                       endpoint=self.endpoint, tensor_part=next(chunks)))
+        for chunk in chunks:
+            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
+        await stream.done_writing()
+
+        outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
+        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
+        if code != averaging_pb2.AVERAGED_PART:
+            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                                      f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                      f" allreduce failed")
 
+        averaged_part = deserialize_torch_tensor(combine_from_streaming([message.tensor_part for message in outputs]))
+        self.register_averaged_part(peer_endpoint, averaged_part)
+        return averaged_part
+
     async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
-            group_id=self.group_id, endpoint=self.endpoint, code=code))
+        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
+        await stream.done_writing()
 
     async def run(self) -> Sequence[torch.Tensor]:
-        """ send allreduce requests to all peers and collect results, return the averaged tensor """
+        """
+        send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
+        """
         try:
             await asyncio.gather(self, *(self._average_one_part(peer, part)
                                          for peer, part in self.local_tensor_parts.items() if peer != self.endpoint))
             return await self
-        except Exception as e:
+        except BaseException as e:
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
             self.set_exception(e)
             for peer_endpoint in self.ordered_group_endpoints:
-                asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
+                if peer_endpoint != self.endpoint:
+                    asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
             raise
 
-    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
+    async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
+                                        ) -> Iterable[runtime_pb2.Tensor]:
+        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
+        tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
+        averaged_part = await self.accumulate_part(source, tensor_part)
+        if not self.averaged_part_stream.done():
+            serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
+            stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
+            self.averaged_part_stream.set_result(stream_chunks)
+            return stream_chunks
+        else:
+            return self.averaged_part_stream.result()
+
+    async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+                                 ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
+        request: averaging_pb2.AveragingData = await anext(stream)
+
         if request.group_id != self.group_id:
-            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
 
-        if request.code == averaging_pb2.PART_FOR_AVERAGING:
+        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                tensor_part = deserialize_torch_tensor(request.tensor_part)
-                averaged_part = await self.accumulate_part(request.endpoint, tensor_part)
-                serialized = serialize_torch_tensor(averaged_part, request.tensor_part.compression, allow_inplace=False)
-                return averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized)
+                tensor_chunks = (request.tensor_part, *[msg.tensor_part async for msg in stream])
+                averaged_chunks = iter(await self.accumulate_part_streaming(request.endpoint, tensor_chunks))
+                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=next(averaged_chunks))
+                for averaged_chunk in averaged_chunks:
+                    yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
             except Exception as e:
                 self.set_exception(e)
-                return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
             logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
             self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
-            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
 
 def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor, ...]:
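
For reference, the partitioning contract this file relies on: split_into_parts flattens all tensors into a single vector and hands each of the group_size peers one contiguous chunk, and restore_from_parts reverses the process using the saved shapes. The sketch below illustrates matching semantics; it is an illustration, not the exact implementation from this file.

from typing import Sequence, Tuple
import torch

def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor, ...]:
    """ flatten all tensors into one vector; the i-th chunk will be averaged by the i-th peer """
    flat = torch.cat([tensor.flatten() for tensor in tensors])
    return tuple(torch.tensor_split(flat, group_size))

def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
    """ concatenate chunks back into one vector and cut it into tensors of the original shapes """
    flat = torch.cat(tuple(chunks))
    result, offset = [], 0
    for shape in shapes:
        numel = int(torch.Size(shape).numel())
        result.append(flat[offset: offset + numel].reshape(shape))
        offset += numel
    return tuple(result)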

+ 196 - 138
hivemind/client/averaging/matchmaking.py

@@ -6,7 +6,7 @@ import contextlib
 import random
 from dataclasses import asdict
 from math import isfinite
-from typing import Sequence, Optional, AsyncIterator, Set
+from typing import Sequence, Optional, AsyncIterator, Set, Tuple
 import asyncio
 
 import torch
@@ -27,29 +27,41 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
     f"""
     f"""
     An internal class that is used to form groups of averages for running allreduce
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
     See DecentralizedAverager docstring for the detailed description of all parameters
+    
+    :note: on implementation: the current matchmaker protocol can encounter one type of (temporary) deadlock;
+      This deadlock occurs when averager A requests averager B at the same time as averager B requests averager A.
+      In that case, neither averager can process the other one's request because it is awaiting lock_request_join_group.
+      This deadlock only happens if averagers have outdated information on expirations (due to network delays). 
+      While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
+      Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
+    
     """
     """
 
 
     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
                  prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
                  prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, compression_type: runtime_pb2.CompressionType = runtime_pb2.NONE):
+                 averaging_expiration: float = 15, request_timeout: float, **allreduce_kwargs):
         assert '.' not in prefix, "group prefix must be a string without ."
         assert '.' not in prefix, "group prefix must be a string without ."
+        if request_timeout is None or request_timeout >= averaging_expiration:
+            logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
+                           "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring.")
 
 
         super().__init__()
         super().__init__()
         self.dht, self.endpoint, self.averaged_tensors = dht, endpoint, tuple(averaged_tensors)
         self.dht, self.endpoint, self.averaged_tensors = dht, endpoint, tuple(averaged_tensors)
         self.prefix, self.group_bits = prefix, initial_group_bits
         self.prefix, self.group_bits = prefix, initial_group_bits
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
-        self.averaging_expiration, self.compression_type = averaging_expiration, compression_type
-
+        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.allreduce_kwargs = allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
 
 
         self.lock_looking_for_group = asyncio.Lock()
         self.lock_looking_for_group = asyncio.Lock()
         self.lock_request_join_group = asyncio.Lock()
         self.lock_request_join_group = asyncio.Lock()
-        self.cond_notify_followers = asyncio.Condition()
+        self.follower_was_discarded = asyncio.Event()
+        self.was_accepted_to_group = asyncio.Event()
         self.assembled_group = asyncio.Future()
         self.assembled_group = asyncio.Future()
 
 
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Set[Endpoint] = set()  # iff i am a leader, this contains my followers excluding myself
         self.current_followers: Set[Endpoint] = set()  # iff i am a leader, this contains my followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.endpoint, self.dht, self.averaging_expiration)
+        self.potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration, target_group_size)
 
 
     @property
     @property
     def is_looking_for_group(self):
     def is_looking_for_group(self):
@@ -70,7 +82,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
                f" current key = {self.current_group_key})"
 
-    async def look_for_group(self, *, timeout: Optional[float] = None) -> AllReduceRunner:
+    async def look_for_group(self, *, timeout: Optional[float] = None) -> Optional[AllReduceRunner]:
         """
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
@@ -82,48 +94,58 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
             try:
                 return await asyncio.wait_for(self.assembled_group, timeout=timeout)
-            except Exception as e:
+            except asyncio.TimeoutError:
+                return None
+
+            except BaseException as e:
                 if len(self.current_followers) > 0:
                     async with self.lock_request_join_group:
                         await self.leader_disband_group()
-                self.assembled_group.set_exception(e)
+                if not self.assembled_group.done():
+                    self.assembled_group.set_exception(e)
                 raise
 
             finally:
                 if not request_leaders_task.done():
                     request_leaders_task.cancel()
-                if self.assembled_group.done():
-                    self.assembled_group = asyncio.Future()
+                if not self.assembled_group.done():
+                    self.assembled_group.cancel()
+                while len(self.current_followers) > 0:
+                    await self.follower_was_discarded.wait()
+                    self.follower_was_discarded.clear()
+                # note: the code above ensures that we send all followers away before creating new future
+                self.assembled_group = asyncio.Future()
+                self.was_accepted_to_group.clear()
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
-        end_time = get_dht_time() + timeout if timeout is not None else float('inf')
         async with self.potential_leaders.begin_search(self.current_group_key, timeout):
             # TODO update group_bits on success! reduce number of bits on not enough peers.
             # TODO after allreduce finishes, we may need to ask leader to notify lower keys about this
             # (so as to fix possible network partitioning if some peers operate on a much smaller nbits)
             while True:
                 try:
-                    time_to_expiration = self.potential_leaders.declared_expiration_time - get_dht_time()
-                    next_best_leader = await asyncio.wait_for(
-                        self.potential_leaders.pop_next_leader(),
-                        timeout=time_to_expiration if isfinite(time_to_expiration) else None)
-
-                    request_expiration_time = min(self.potential_leaders.declared_expiration_time,
-                                                  end_time, get_dht_time() + self.averaging_expiration)
-                    group = await self.request_join_group(next_best_leader, request_expiration_time)
+                    next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
+
+                    group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
                     if group is not None:
                         return group
 
                 except asyncio.TimeoutError:
                     async with self.lock_request_join_group:
-                        if len(self.current_followers) >= self.min_group_size:
+                        if self.assembled_group.done():
+                            return self.assembled_group.result()
+                        elif len(self.current_followers) + 1 >= self.min_group_size:
                             # the time is up, we have a *good enough* group. run allreduce as is.
                             return await self.leader_assemble_group()
-                        else:
+                        elif len(self.current_followers) > 0:
                             await self.leader_disband_group()
                             # TODO maybe adjust grid size
-                            continue
+                        continue
+                except Exception as e:
+                    if not self.assembled_group.done():
+                        self.assembled_group.set_exception(e)
+                    raise e
 
     async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[AllReduceRunner]:
         """
@@ -134,87 +156,101 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
           The originally specified leader can disband group and redirect us to a different leader
           The originally specified leader can disband group and redirect us to a different leader
         """
         """
         assert self.is_looking_for_group and self.current_leader is None
-        call: Optional[grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]] = None
+        call: Optional[grpc.aio.UnaryStreamCall] = None
         try:
             async with self.lock_request_join_group:
                 leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                 call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                     endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time))
+                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)

-                message = await call.read()
-                if message.code != averaging_pb2.ACCEPTED:
-                    code = averaging_pb2.MessageCode.Name(message.code)
-                    logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
-                    return None
+                if message.code == averaging_pb2.ACCEPTED:
+                    logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                    self.current_leader = leader
+                    self.was_accepted_to_group.set()
+                    if len(self.current_followers) > 0:
+                        await self.leader_disband_group()

-                # else: we were accepted
-                logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
-                self.current_leader = leader
-                if len(self.current_followers) > 0:
-                    await self.leader_disband_group()
+            if message.code != averaging_pb2.ACCEPTED:
+                code = averaging_pb2.MessageCode.Name(message.code)
+                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                return None

             async with self.potential_leaders.pause_search():
-                message = await call.read()
+                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
+                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)

-            if message.code == averaging_pb2.BEGIN_ALLREDUCE:
-                async with self.lock_request_join_group:
-                    return await self.follower_assemble_group(leader, message.group_id, message.ordered_group_endpoints)
-            elif message.code == averaging_pb2.GROUP_DISBANDED and bool(message.suggested_leader):
-                logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
-                return await self.request_join_group(message.suggested_leader, expiration_time)
+                if message.code == averaging_pb2.BEGIN_ALLREDUCE:
+                    async with self.lock_request_join_group:
+                        return await self.follower_assemble_group(
+                            leader, message.group_id, message.ordered_group_endpoints)
+
+            if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
+                if message.suggested_leader and message.suggested_leader != self.endpoint:
+                    logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
+                    self.current_leader = None
+                    call.cancel()
+                    return await self.request_join_group(message.suggested_leader, expiration_time)
+                else:
+                    logger.debug(f"{self} - leader disbanded group")
+                    return None

-            else:
-                logger.debug(f"{self} - leader sent {averaging_pb2.MessageCode.Name(message.code)}, leaving group")
-                return None
+            logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
+            return None
+        except asyncio.TimeoutError:
+            logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
+            if call is not None:
+                call.cancel()
+            return None
         finally:
+            self.was_accepted_to_group.clear()
             self.current_leader = None
             if call is not None:
-                call.cancel()
+                await call.code()

     async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                              ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
         try:
-            reason_to_reject = self._check_reasons_to_reject(request)
-            if reason_to_reject is not None:
-                yield reason_to_reject
-                return
-
-            current_group = self.assembled_group  # copy current assembled_group to avoid overwriting
             async with self.lock_request_join_group:
+                reason_to_reject = self._check_reasons_to_reject(request)
+                if reason_to_reject is not None:
+                    yield reason_to_reject
+                    return
+
                 self.current_followers.add(request.endpoint)
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)

-                if len(self.current_followers) + 1 >= self.target_group_size:
+                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
                     # outcome 1: we have assembled a full group and are ready for allreduce
                     await self.leader_assemble_group()

-            if not current_group.done():
-                try:
-                    async with self.cond_notify_followers:
-                        # wait for the group to be assembled or disbanded
-                        timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
-                        await asyncio.wait_for(self.cond_notify_followers.wait(), timeout=timeout)
-                except asyncio.TimeoutError:
-                    async with self.lock_request_join_group:
+            # wait for the group to be assembled or disbanded
+            timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
+            await asyncio.wait({self.assembled_group, self.was_accepted_to_group.wait()},
+                               return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
+            if not self.assembled_group.done() and not self.was_accepted_to_group.is_set():
+                async with self.lock_request_join_group:
+                    if self.assembled_group.done():
+                        pass  # this covers a rare case when the group is assembled while the event loop was busy.
+                    elif len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
                         # outcome 2: the time is up, run allreduce with what we have or disband
-                        if len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
-                            await self.leader_assemble_group()
-                        else:
-                            await self.leader_disband_group()
-
-            if self.current_leader is not None:
-                # outcome 3: found by a leader with higher priority, send our followers to him
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
-                                                      suggested_leader=self.current_leader)
-                return
+                        await self.leader_assemble_group()
+                    else:
+                        await self.leader_disband_group()

-            if request.endpoint not in self.current_followers:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
-                return
+            if self.was_accepted_to_group.is_set() or not self.assembled_group.done() \
+                    or self.assembled_group.cancelled() or request.endpoint not in self.assembled_group.result():
+                if self.current_leader is not None:
+                    # outcome 3: found by a leader with higher priority, send our followers to him
+                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
+                                                          suggested_leader=self.current_leader)
+                    return
+                else:
+                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
+                    return

-            # finally, run allreduce
-            allreduce_group = current_group.result()
+            allreduce_group = self.assembled_group.result()
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
                 ordered_group_endpoints=allreduce_group.ordered_group_endpoints)
@@ -225,10 +261,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):

         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
             self.current_followers.discard(request.endpoint)
+            self.follower_was_discarded.set()

-    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> averaging_pb2.MessageFromLeader:
+    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
         """ :returns: if accepted, return None, otherwise return a reason for rejection """
-        if not self.is_looking_for_group:
+        if not self.is_looking_for_group or self.assembled_group.done():
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)

         if request.ListFields() == 3 and not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
@@ -243,8 +280,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float('inf')):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
         elif self.current_leader is not None:
-            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER,
-                                                   suggested_leader=self.current_leader)
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
+                                                   )  # note: this suggested leader is currently ignored
         elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
         elif len(self.current_followers) + 1 >= self.target_group_size:
@@ -255,68 +292,71 @@
     async def leader_assemble_group(self) -> AllReduceRunner:
         """ Form up all current followers into a group and prepare to _run_allreduce """
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        assert not self.assembled_group.done()
         group_id = DHTID.generate().to_bytes()
         ordered_group_endpoints = list(self.current_followers)
         ordered_group_endpoints.append(self.endpoint)
         random.shuffle(ordered_group_endpoints)
-        logger.debug(f"{self.endpoint} - leader started allreduce with {len(ordered_group_endpoints)} followers.")
-        allreduce_group = AllReduceRunner(
-            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
+        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
-        async with self.cond_notify_followers:
-            self.cond_notify_followers.notify_all()
         return allreduce_group

     async def follower_assemble_group(self, leader: Endpoint, group_id: GroupID,
                                       ordered_group_endpoints: Sequence[Endpoint]) -> AllReduceRunner:
         """ Prepare to run allreduce using a list of peers provided by our leader """
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        assert not self.assembled_group.done()
         logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
-        allreduce_group = AllReduceRunner(
-            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
-        async with self.cond_notify_followers:
-            self.cond_notify_followers.notify_all()
         return allreduce_group

     async def leader_disband_group(self):
         """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
         assert self.lock_request_join_group.locked()
         self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out
-        async with self.cond_notify_followers:
-            self.cond_notify_followers.notify_all()


 class PotentialLeaders:
     """ A utility class that searches for averagers that could become our leaders """
-    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration):
+
+    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration,
+                 target_group_size: Optional[int]):
         self.endpoint, self.dht, self.averaging_expiration = endpoint, dht, averaging_expiration
+        self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
+        self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
         self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
-        self.max_assured_time = float('-inf')
+        self.past_attempts: Set[Tuple[Endpoint, DHTExpiration]] = set()
         self.declared_expiration_time = float('inf')
         self.declared_group_key: Optional[GroupKey] = None
+        self.max_assured_time = float('-inf')
         self.search_end_time = float('inf')

     @contextlib.asynccontextmanager
     async def begin_search(self, group_key: GroupKey, timeout: Optional[float]):
-        assert not self.running.is_set(), "already running"
-        self.running.set()
-        self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
-        update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
-        declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
-        try:
-            yield self
-        finally:
-            update_queue_task.cancel()
-            declare_averager_task.cancel()
-            self.running.clear()
-            self.update_triggered.clear()
-            self.update_finished.clear()
+        async with self.lock_search:
+            self.running.set()
+            self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
+            update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
+            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
+            try:
+                yield self
+            finally:
+                if not update_queue_task.done():
+                    update_queue_task.cancel()
+                if not declare_averager_task.done():
+                    declare_averager_task.cancel()
+                for field in (self.past_attempts, self.leader_queue, self.running,
+                              self.update_finished, self.update_triggered, self.declared_expiration):
+                    field.clear()
+                self.max_assured_time = float('-inf')
+                self.search_end_time = float('inf')

     @contextlib.asynccontextmanager
     async def pause_search(self):
@@ -332,19 +372,34 @@ class PotentialLeaders:

     async def pop_next_leader(self) -> Endpoint:
         """ Remove and return the next most suitable leader or throw an exception if reached timeout """
-        assert self.running, "Not running search at the moment"
-        maybe_next_leader, entry = self.leader_queue.top()
-
-        next_entry_time = entry.expiration_time if maybe_next_leader is not None else get_dht_time()
-        if self.max_assured_time < next_entry_time < self.search_end_time:
-            self.update_triggered.set()
+        assert self.running.is_set(), "Not running search at the moment"
+        while True:
+            maybe_next_leader, entry = self.leader_queue.top()
+
+            if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
+                self.update_triggered.set()
+
+            if maybe_next_leader is None or entry.expiration_time >= self.declared_expiration_time:
+                await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
+                                   return_when=asyncio.FIRST_COMPLETED)
+                self.declared_expiration.clear()
+                if self.update_finished.is_set():
+                    self.update_finished.clear()
+                    continue
+                else:
+                    raise asyncio.TimeoutError("pop_next_leader was invalidated: re-declared averager in background")

-        if maybe_next_leader is None:
-            await self.update_finished.wait()
-            return await self.pop_next_leader()
+            del self.leader_queue[maybe_next_leader]
+            self.past_attempts.add((maybe_next_leader, entry.expiration_time))
+            return maybe_next_leader

-        del self.leader_queue[maybe_next_leader]
-        return maybe_next_leader
+    @property
+    def request_expiration_time(self) -> float:
+        """ this averager's current expiration time - used to send join requests to leaders """
+        if isfinite(self.declared_expiration_time):
+            return self.declared_expiration_time
+        else:
+            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)

     async def _update_queue_periodically(self, group_key: GroupKey):
         DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
@@ -352,14 +407,14 @@
             new_peers = await self.dht.get_averagers(group_key, only_active=True, return_future=True)
             self.max_assured_time = max(self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY)

+            self.leader_queue.clear()
             for peer, peer_expiration_time in new_peers:
-                if peer == self.endpoint:
+                if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
                     continue
                 self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                 self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)

-            if len(self.leader_queue) > 0:
-                self.update_finished.set()
+            self.update_finished.set()

             await asyncio.wait(
                 {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
@@ -367,28 +422,31 @@
             self.update_triggered.clear()

     async def _declare_averager_periodically(self, group_key: GroupKey):
-        try:
-            while True:
-                new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-                self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
-                stored_ok = await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
-                                                            looking_for_group=True, return_future=True)
-                if stored_ok:
+        async with self.lock_declare:
+            try:
+                while True:
+                    await self.running.wait()
+
+                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                    self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
+                    self.declared_expiration.set()
+                    await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
+                                                    looking_for_group=True, return_future=True)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
-                else:
-                    logger.warning(f"Failed to subscribe to group {group_key} : store rejected by DHT peers")
-        finally:
-            if self.declared_group_key is not None:
-                previous_declared_key, previous_expiration_time = self.declared_group_key, self.declared_expiration_time
-                self.declared_group_key, self.declared_expiration_time = None, float('inf')
-                self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
-                await self.dht.declare_averager(previous_declared_key, self.endpoint, previous_expiration_time,
-                                                looking_for_group=False, return_future=True)
+            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
+                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            finally:
+                if self.declared_group_key is not None:
+                    prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
+                    self.declared_group_key, self.declared_expiration_time = None, float('inf')
+                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
+                    await self.dht.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
+                                                    looking_for_group=False, return_future=True)


 def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
     """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
     schema_dicts = [{field_name: str(field_value)
-                    for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
+                     for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                     for tensor in tensors]
     return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
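
Usage sketch for the schema hash (hedged: the import path is an assumption, since the diff does
not show which module re-exports compute_schema_hash). Only tensor metadata is hashed, so
averagers with identically-shaped parameters are matched regardless of the actual values:

    import torch
    from hivemind.client.averaging.matchmaking import compute_schema_hash  # import path assumed

    # same shapes and dtypes -> same hash; the values themselves are irrelevant
    assert compute_schema_hash([torch.zeros(8)]) == compute_schema_hash([torch.randn(8)])
    # a different dtype -> a different hash, so these peers would never be grouped together
    assert compute_schema_hash([torch.zeros(8)]) != compute_schema_hash([torch.zeros(8, dtype=torch.float16)])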

+ 7 - 9
hivemind/dht/__init__.py

@@ -27,7 +27,7 @@ from numpy import nextafter
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time, DHTValue
-from hivemind.utils import MPFuture, Endpoint, get_logger
+from hivemind.utils import MPFuture, Endpoint, get_logger, switch_to_uvloop

 logger = get_logger(__name__)

@@ -141,11 +141,7 @@ class DHT(mp.Process):

     def run(self) -> None:
         """ Serve DHT forever. This function will not return until DHT node is shut down """
-        if asyncio.get_event_loop().is_running():
-            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-        uvloop.install()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+        loop = switch_to_uvloop()
         pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)

         async def _run():
@@ -497,13 +493,14 @@ class DHT(mp.Process):
     async def _declare_averager(self, node: DHTNode, *, group_key: str, endpoint: Endpoint,
                                 expiration_time: DHTExpiration, looking_for_group: bool, future: MPFuture):
         try:
-            expiration_time = expiration_time if looking_for_group else nextafter(expiration_time, float('inf'))
+            expiration_time = expiration_time if looking_for_group else float(nextafter(expiration_time, float('inf')))
             # ^-- when declaring averager inactive, we increment expiration time to overwrite the pre-existing entry
             store_ok = await node.store(
                 key=group_key, subkey=endpoint, value=looking_for_group, expiration_time=expiration_time)
             future.set_result(store_ok)
         except Exception as e:
-            future.set_exception(e)
+            if not future.done():
+                future.set_exception(e)

     def get_averagers(self, group_key: GroupKey, *, only_active: bool = True, return_future: bool = False
                       ) -> Union[List[Tuple[Endpoint, DHTExpiration]], MPFuture]:
@@ -534,4 +531,5 @@ class DHT(mp.Process):
                          if not only_active or entry.value is True]
             future.set_result(averagers)
         except Exception as e:
-            future.set_exception(e)
+            if not future.done():
+                future.set_exception(e)
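
A quick sanity-check sketch for the two averager helpers above (hedged: the group key string is
an arbitrary example following the prefix.0b<bits> pattern used elsewhere in this PR, and the
endpoint does not need to be reachable for a plain store/lookup round-trip):

    import hivemind
    from hivemind.utils import get_dht_time

    dht = hivemind.DHT(start=True)
    assert dht.declare_averager('my_tensor.0b0110', '127.0.0.1:1337', get_dht_time() + 15,
                                looking_for_group=True)
    print(dht.get_averagers('my_tensor.0b0110', only_active=True))  # [('127.0.0.1:1337', <expiration>)]
    dht.shutdown()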

+ 15 - 5
hivemind/dht/node.py

@@ -357,7 +357,7 @@ class DHTNode:
         """
         """
         if latest:
         if latest:
             kwargs["sufficient_expiration_time"] = float('inf')
             kwargs["sufficient_expiration_time"] = float('inf')
-        result = await self.get_many([key])
+        result = await self.get_many([key], **kwargs)
         return result[key]
         return result[key]
 
 
     async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
     async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
@@ -579,10 +579,20 @@ class _SearchState:
     future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
     future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
     serializer: type(SerializerBase) = MSGPackSerializer
     serializer: type(SerializerBase) = MSGPackSerializer
 
 
-    def add_candidate(self, candidate: Optional[ValueWithExpiration[BinaryDHTValue]], source_node_id: Optional[DHTID]):
-        binary_value, expiration_time = candidate or (None, -float('inf'))
-        if not self.finished and expiration_time > (self.expiration_time or -float('inf')):
-            self.binary_value, self.expiration_time, self.source_node_id = binary_value, expiration_time, source_node_id
+    def add_candidate(self, candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
+                      source_node_id: Optional[DHTID]):
+        if self.finished or candidate is None:
+            return
+        elif isinstance(candidate.value, DictionaryDHTValue) and isinstance(self.binary_value, DictionaryDHTValue):
+            self.binary_value.maxsize = max(self.binary_value.maxsize, candidate.value.maxsize)
+            for subkey, subentry in candidate.value.items():
+                self.binary_value.store(subkey, subentry.value, subentry.expiration_time)
+        elif candidate.expiration_time > (self.expiration_time or float('-inf')):
+            self.binary_value = candidate.value
+
+        if candidate.expiration_time > (self.expiration_time or float('-inf')):
+            self.expiration_time = candidate.expiration_time
+            self.source_node_id = source_node_id
             if self.expiration_time >= self.sufficient_expiration_time:
             if self.expiration_time >= self.sufficient_expiration_time:
                 self.finish_search()
                 self.finish_search()
 
 
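The get() fix above is easy to miss but important: the kwargs (including the
sufficient_expiration_time injected by latest=True) were previously dropped instead of being
forwarded to get_many, so latest=True could still return a stale cached value. A hedged sketch
of the intended call pattern (standalone node, error handling omitted):

    import asyncio
    from hivemind.dht.node import DHTNode
    from hivemind.utils import get_dht_time

    async def demo():
        node = await DHTNode.create()
        await node.store('key', b'value', expiration_time=get_dht_time() + 60)
        # latest=True now reaches get_many as sufficient_expiration_time=inf,
        # forcing a fresh search instead of an early return from cache
        return await node.get('key', latest=True)

    asyncio.run(demo())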

+ 1 - 1
hivemind/dht/protocol.py

@@ -44,7 +44,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """
         """
         self = cls(_initialized_with_create=True)
         self = cls(_initialized_with_create=True)
         self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
         self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
-        self.wait_timeout, self.channel_options = wait_timeout, channel_options
+        self.wait_timeout, self.channel_options = wait_timeout, tuple(channel_options)
         self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))

+ 1 - 1
hivemind/proto/averaging.proto

@@ -5,7 +5,7 @@ import "runtime.proto";
 // Runs alongside each trainer to perform gating function averaging every now and then. Read more: client/averaging.py
 service DecentralizedAveraging {
   rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
-  rpc rpc_aggregate_part(AveragingData) returns (AveragingData);  // send my local shard => get aggregated shard
+  rpc rpc_aggregate_part(stream AveragingData) returns (stream AveragingData);  // send local part => get average part
 }

 enum MessageCode {

+ 1 - 0
hivemind/proto/runtime.proto

@@ -38,5 +38,6 @@ message Tensor {
   bool requires_grad = 3;
   string dtype = 4;
   CompressionType compression = 5;
+  int32 chunks = 6;
 }


+ 1 - 0
hivemind/utils/__init__.py

@@ -7,3 +7,4 @@ from hivemind.utils.threading import *
 from hivemind.utils.grpc import *
 from hivemind.utils.timed_storage import *
 from hivemind.utils.logging import get_logger
+from hivemind.utils.asyncio import *

+ 34 - 0
hivemind/utils/asyncio.py

@@ -0,0 +1,34 @@
+from typing import TypeVar, AsyncIterator, Union, AsyncIterable
+import asyncio
+import uvloop
+T = TypeVar('T')
+
+
+def switch_to_uvloop() -> asyncio.AbstractEventLoop:
+    """ stop any running event loops; install uvloop; then create, set and return a new event loop """
+    try:
+        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+    except RuntimeError as error_no_event_loop:
+        pass  # this allows running DHT from background threads with no event loop
+    uvloop.install()
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    return loop
+
+
+async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
+    """ equivalent to next(iter) for asynchronous iterators. Modifies aiter in-place! """
+    return await aiter.__anext__()
+
+
+async def aiter(*args: T) -> AsyncIterator[T]:
+    """ create an asynchronous iterator from a sequence of values """
+    for arg in args:
+        yield arg
+
+
+async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
+    """ equivalent to chain(iter1, iter2, ...) for asynchronous iterators. """
+    for aiter in async_iters:
+        async for elem in aiter:
+            yield elem

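A minimal usage sketch for the new helpers (self-contained; switch_to_uvloop matters only when a
fresh process takes over the event loop, as DHT.run now does in hivemind/dht/__init__.py above):

    import asyncio
    from hivemind.utils.asyncio import aiter, achain, anext

    async def demo():
        stream = achain(aiter(1, 2), aiter(3))   # chain two async iterators into one
        assert await anext(stream) == 1          # next(...)-style step for async iterators
        assert [item async for item in stream] == [2, 3]

    asyncio.run(demo())
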
+ 29 - 4
hivemind/utils/grpc.py

@@ -2,19 +2,20 @@
 Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
 """
 from __future__ import annotations
+
 import os
 import threading
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type
+from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable

 import grpc
 import numpy as np
 import torch
+from hivemind.proto.runtime_pb2 import CompressionType

 from hivemind.proto import runtime_pb2
-from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
-from hivemind.utils.networking import Endpoint
 from hivemind.utils.logging import get_logger
+from hivemind.utils.networking import Endpoint
+from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration

 logger = get_logger(__name__)

@@ -235,3 +236,27 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten

     tensor.requires_grad_(serialized_tensor.requires_grad)
     return tensor
+
+
+def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int) -> Iterator[runtime_pb2.Tensor]:
+    """ Split serialized_tensor into multiple chunks for gRPC streaming """
+    buffer = memoryview(serialized_tensor.buffer)
+    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
+    yield runtime_pb2.Tensor(
+        compression=serialized_tensor.compression, buffer=buffer[:chunk_size_bytes].tobytes(), chunks=num_chunks,
+        size=serialized_tensor.size, dtype=serialized_tensor.dtype, requires_grad=serialized_tensor.requires_grad)
+    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
+        yield runtime_pb2.Tensor(buffer=buffer[chunk_start: chunk_start + chunk_size_bytes].tobytes())
+
+
+def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
+    """ Restore a result of split_into_chunks into a single serialized tensor """
+    stream = iter(stream)
+    first_chunk = next(stream)
+    serialized_tensor = runtime_pb2.Tensor()
+    serialized_tensor.CopyFrom(first_chunk)
+    buffer_chunks = [first_chunk.buffer]
+    for tensor_part in stream:
+        buffer_chunks.append(tensor_part.buffer)
+    serialized_tensor.buffer = b''.join(buffer_chunks)
+    return serialized_tensor

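The two new helpers are inverses: the first chunk carries the tensor metadata plus the total
chunk count (the new chunks field in runtime.proto above), and every later chunk carries only
raw buffer bytes. A usage sketch mirroring the new test at the bottom of this diff:

    import torch
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils.grpc import (serialize_torch_tensor, deserialize_torch_tensor,
                                     split_for_streaming, combine_from_streaming)

    tensor = torch.randn(1024, 1024)
    serialized = serialize_torch_tensor(tensor, CompressionType.NONE)
    chunks = list(split_for_streaming(serialized, chunk_size_bytes=64 * 1024))  # 4 MiB -> 64 chunks
    restored = deserialize_torch_tensor(combine_from_streaming(chunks))
    assert torch.allclose(restored, tensor)
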
+ 1 - 1
hivemind/utils/mpfuture.py

@@ -54,7 +54,7 @@ class MPFuture(base.Future):
                 self.connection.close()
         except TimeoutError as e:
             raise e
-        except (BrokenPipeError, OSError) as e:
+        except (BrokenPipeError, OSError, EOFError) as e:
             if self._state in (base.PENDING, base.RUNNING):
                 self._state, self._exception = base.FINISHED, e


+ 4 - 0
hivemind/utils/tensor_descr.py

@@ -1,3 +1,4 @@
+import warnings
 from dataclasses import dataclass, asdict

 import torch
@@ -6,6 +7,9 @@ from hivemind.proto.runtime_pb2 import CompressionType

 DUMMY_BATCH_SIZE = 3  # used for dummy runs only

+warnings.filterwarnings("ignore", "CUDA initialization*", category=UserWarning)
+# ^-- cures https://github.com/pytorch/pytorch/issues/47038
+

 @dataclass(init=True, repr=True, frozen=True)
 class DescriptorBase:

+ 5 - 0
hivemind/utils/timed_storage.py

@@ -81,6 +81,11 @@ class TimedStorage(Generic[KeyType, ValueType]):
             return top_key, self.data[top_key]
         return None, None

+    def clear(self):
+        self.data.clear()
+        self.key_to_heap.clear()
+        self.expiration_heap.clear()
+
     def __contains__(self, key: KeyType):
         self._remove_outdated()
         return key in self.data

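clear() complements the per-entry expiration logic: PotentialLeaders uses it above to flush the
whole leader queue before each refresh. A small self-contained sketch:

    from hivemind.utils.timed_storage import TimedStorage, get_dht_time

    storage = TimedStorage[str, int]()
    storage.store('peer', 42, expiration_time=get_dht_time() + 60)
    assert 'peer' in storage
    storage.clear()  # drops the data, the key-to-heap index and the expiration heap together
    assert 'peer' not in storage
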
+ 88 - 0
tests/benchmark_averaging.py

@@ -0,0 +1,88 @@
+import time
+import threading
+import argparse
+
+import torch
+import hivemind
+from hivemind.utils import LOCALHOST, increase_file_limit
+from hivemind.proto import runtime_pb2
+
+
+def sample_tensors(hid_size, num_layers):
+    tensors = []
+    for i in range(num_layers):
+        tensors.append(torch.randn(hid_size, 3 * hid_size))
+        tensors.append(torch.randn(3 * hid_size))
+        tensors.append(torch.randn(3 * hid_size))
+        tensors.append(torch.randn(hid_size, hid_size))
+        tensors.append(torch.ones(hid_size))
+        tensors.append(torch.zeros(hid_size))
+        tensors.append(torch.randn(hid_size, 4 * hid_size))
+        tensors.append(torch.randn(4 * hid_size))
+        tensors.append(torch.ones(4 * hid_size))
+        tensors.append(torch.randn(2, hid_size, hid_size, 2))
+        tensors.append(torch.randn(hid_size))
+        tensors.append(torch.randn(hid_size))
+        tensors.append(torch.randn(hid_size))
+    return tuple(tensors)
+
+
+def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
+                        averaging_expiration: float, request_timeout: float, round_timeout: float,
+                        hid_size: int, num_layers: int, spawn_dtime: float):
+    dht_root = hivemind.DHT(listen_on=f'{LOCALHOST}:*', start=True)
+    peer_tensors = [sample_tensors(hid_size, num_layers)
+                    for _ in range(num_peers)]
+    processes = {dht_root}
+
+    def run_averager(index):
+        dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
+                           initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
+                           start=True)
+        averager = hivemind.DecentralizedAverager(
+            peer_tensors[index], dht, prefix='my_tensor', initial_group_bits='0110', listen_on=f"{LOCALHOST}:*",
+            compression_type=runtime_pb2.CompressionType.FLOAT16, target_group_size=target_group_size,
+            averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
+        processes.update({dht, averager})
+
+        print(end=f'<started {index}>\n', flush=True)
+        for _ in range(num_rounds):
+            success = averager.step(timeout=round_timeout)
+            print(end=('+' if success else '-'), flush=True)
+        print(end=f'<finished {index}>\n', flush=True)
+
+    threads = []
+    for i in range(num_peers):
+        thread = threading.Thread(target=run_averager, args=[i])
+        threads.append(thread)
+        thread.start()
+        time.sleep(spawn_dtime)
+
+    t = time.time()
+    for thread in threads:
+        thread.join()
+
+    print(f"\ntest run took {time.time() - t:.3f} seconds")
+
+    for process in processes:
+        process.terminate()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--num_peers', type=int, default=16, required=False)
+    parser.add_argument('--target_group_size', type=int, default=4, required=False)
+    parser.add_argument('--num_rounds', type=int, default=5, required=False)
+    parser.add_argument('--hid_size', type=int, default=256, required=False)
+    parser.add_argument('--num_layers', type=int, default=3, required=False)
+    parser.add_argument('--averaging_expiration', type=float, default=15, required=False)
+    parser.add_argument('--round_timeout', type=float, default=30, required=False)
+    parser.add_argument('--request_timeout', type=float, default=3, required=False)
+    parser.add_argument('--spawn_dtime', type=float, default=0.1, required=False)
+    parser.add_argument('--increase_file_limit', action="store_true")
+    args = vars(parser.parse_args())
+
+    if args.pop('increase_file_limit', False):
+        increase_file_limit()
+
+    benchmark_averaging(**args)

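To try the benchmark locally, the flags map one-to-one onto the argparse block above, e.g.:

    python tests/benchmark_averaging.py --num_peers 16 --target_group_size 4 --num_rounds 5
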
+ 10 - 9
tests/test_averaging.py

@@ -34,8 +34,7 @@ def test_getset_averagers():


 @pytest.mark.forked
-@pytest.mark.asyncio
-async def test_allreduce_once():
+def test_allreduce_once():
     dht = hivemind.DHT(start=True)

     tensors1 = [torch.randn(123), torch.zeros(3)]
@@ -52,12 +51,14 @@ async def test_allreduce_once():

     futures = []
     for averager in averagers:
-        futures.append(averager.step(return_future=True))  # TODO revert to hard version
-        time.sleep(0.5)
-
+        futures.append(averager.step(wait=False))
     for future in futures:
-        for ref, our in zip(reference, future.result()):
-            assert torch.allclose(ref, our)
+        assert future.result() is True
+
+    for averager in averagers:
+        with averager.get_tensors() as averaged_tensors:
+            for ref, our in zip(reference, averaged_tensors):
+                assert torch.allclose(ref, our, atol=1e-6)


 @pytest.mark.forked
@@ -90,7 +91,7 @@ async def test_allreduce_protocol():
     ]

     for peer, allreduce in zip(peers, allreduce_protocols):
-        assert allreduce.averaged_tensors.done()
+        assert allreduce.future.done()
         averaged_tensors = await allreduce
         assert len(averaged_tensors) == len(reference_tensors)
         assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
@@ -98,7 +99,7 @@ async def test_allreduce_protocol():


 @pytest.mark.forked
-def test_chunks():
+def test_partitioning():
     for _ in range(100):
         tensors = []
         for _ in range(random.randint(1, 5)):

+ 25 - 0
tests/test_util_modules.py

@@ -166,3 +166,28 @@ async def test_channel_cache():
         for j in range(i + 1, len(all_channels)):
             ci, cj = all_channels[i], all_channels[j]
             assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
+
+
+def test_serialize_tensor():
+    tensor = torch.randn(512, 12288)
+
+    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.NONE)
+    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
+        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+        restored = hivemind.combine_from_streaming(chunks)
+        assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
+
+    chunk_size = 30 * 1024
+    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.FLOAT16)
+    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    restored = hivemind.combine_from_streaming(chunks)
+    assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)
+
+    tensor = torch.randint(0, 100, (512, 1, 1))
+    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.NONE)
+    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    restored = hivemind.combine_from_streaming(chunks)
+    assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)