
Address averaging corner cases, add benchmark_averaging.py, chunk averaged tensors, fix DHTNode get (#134)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
e159605143

+ 0 - 6
.circleci/config.yml

@@ -21,12 +21,6 @@ jobs:
       - run:
           command: pytest ./tests
           name: tests
-      - run:
-          command: python tests/benchmark_throughput.py --preset minimalistic
-          name: benchmark_throughput
-      - run:
-          command: python tests/benchmark_dht.py
-          name: benchmark_dht
       - run:
           command: codecov
           name: codecov

+ 114 - 55
hivemind/client/averaging/__init__.py

@@ -2,22 +2,23 @@
 
 from __future__ import annotations
 
-import random
+import asyncio
+import contextlib
 import ctypes
-from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
-from concurrent.futures.thread import ThreadPoolExecutor
 import multiprocessing as mp
-import asyncio
+import random
+from concurrent.futures.thread import ThreadPoolExecutor
+from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
-import torch
-import uvloop
 import grpc
+import torch
 
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
 from hivemind.client.averaging.matchmaking import Matchmaking
-from hivemind.utils import get_logger, Endpoint, Port, MPFuture, replace_port, GRPC_KEEPALIVE_OPTIONS
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+from hivemind.utils import get_logger, Endpoint, Port, MPFuture, replace_port, GRPC_KEEPALIVE_OPTIONS, get_dht_time
+from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
@@ -46,7 +47,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
     :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
     :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
-
+    :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
+    :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
+    :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
             if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
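
The new request_timeout and chunk_size_bytes parameters documented above are ordinary constructor arguments. A minimal construction sketch, with the DHT setup and tensor shapes assumed for illustration (they are not part of this diff):

import torch
import hivemind
from hivemind.client.averaging import DecentralizedAverager

dht = hivemind.DHT(start=True)  # assumption: a local DHT node with default settings
averager = DecentralizedAverager(
    averaged_tensors=[torch.zeros(16)], dht=dht, start=True,
    prefix='my_experiment', target_group_size=4, min_group_size=2,
    averaging_expiration=15, request_timeout=3, chunk_size_bytes=2 ** 16)
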
@@ -62,13 +65,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     _pending_group_assembled: asyncio.Event
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
-                 prefix: str, target_group_size: int, min_group_size: int = 1, initial_group_bits: Optional[str] = None,
+                 prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
                  averaging_expiration: float = 15, allreduce_timeout: Optional[float] = None,
+                 request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-                 listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1,
+                 listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
-        assert '.' not in prefix, "group prefix must be a string without ."
-        if is_power_of_two(target_group_size):
+        assert '.' not in prefix, "group prefix must be a string without trailing '.'"
+        if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         if initial_group_bits is None:
             initial_group_bits = ''.join(random.choices('01', k=INITIAL_GROUP_NBITS))
@@ -79,16 +83,19 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.dht = dht
         self.listen_on, self.receiver_threads, self.kwargs = listen_on, receiver_threads, kwargs
         self.channel_options = channel_options
-        self.averaged_tensors = tuple(averaged_tensors)
-        # TODO use mp.Lock to prevent someone from modifying tensors before we copy them! maybe.
-        for tensor in self.averaged_tensors:
+        self.daemon = daemon
+
+        self._averaged_tensors = tuple(averaged_tensors)
+        self.lock_averaged_tensors = mp.Lock()
+        for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
 
-        self.matchmaking_kwargs = dict(prefix=prefix, initial_group_bits=initial_group_bits,
-                                       target_group_size=target_group_size, min_group_size=min_group_size,
-                                       averaging_expiration=averaging_expiration)
-        self.allreduce_timeout, self.compression_type = allreduce_timeout, compression_type
+        self.matchmaking_kwargs = dict(
+            prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
+            min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout,
+            chunk_size_bytes=chunk_size_bytes, compression_type=compression_type)
+        self.allreduce_timeout = allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
@@ -115,13 +122,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     def run(self):
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
-        if asyncio.get_event_loop().is_running():
-            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-
-        uvloop.install()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
+        loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
         pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
 
@@ -132,7 +133,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             found_port = server.add_insecure_port(self.listen_on)
             assert found_port != 0, f"Failed to listen to {self.listen_on}"
             self._port.value = found_port
-            self._matchmaking = Matchmaking(self.endpoint, self.averaged_tensors, self.dht, **self.matchmaking_kwargs)
+            self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
+                                            return_deltas=True)  # note: we need deltas to make allreduce lock-free
             self._pending_group_assembled = asyncio.Event()
             self._pending_group_assembled.set()
             await server.start()
@@ -161,37 +163,88 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         else:
             logger.warning("DHT shutdown has no effect: the process is not alive")
 
-    def step(self, timeout: Optional[float] = None, return_future=False) -> Union[Sequence[torch.Tensor], MPFuture]:
+    def step(self, allow_retries: bool = True, timeout: Optional[float] = None, wait=True
+             ) -> Union[bool, MPFuture]:
         """
-        Set up the averager to look for a group and run one round of averaging, then return the averaged tensors
-
+        Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
+        :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
+          within the specified timeout
         :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_step', [], dict(future=_future, timeout=timeout)))
-        return future if return_future else future.result()
+        self.pipe.send(('_step', [], dict(future=_future, allow_retries=allow_retries, timeout=timeout)))
+        return future.result() if wait else future
+
+    async def _step(self, *, future: MPFuture, allow_retries: bool, timeout: Optional[float]):
+        loop = asyncio.get_event_loop()
+        start_time = get_dht_time()
 
-    async def _step(self, *, future: MPFuture, timeout: Optional[float]):
+        try_averaging = True
         group_id = None
-        try:
-            self._pending_group_assembled.clear()
-            allreduce_group = await self._matchmaking.look_for_group(timeout=timeout)
-            group_id = allreduce_group.group_id
-            if allreduce_group is not None:
+
+        while try_averaging:
+            try:
+                self._pending_group_assembled.clear()
+                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout)
+                if allreduce_group is None:
+                    raise AllreduceException("Averaging step failed: could not find a group.")
+
+                group_id = allreduce_group.group_id
                 self._running_groups[group_id] = allreduce_group
                 self._pending_group_assembled.set()
-                future.set_result(await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout))
-            else:
-                raise AllreduceException(f"{self} - group_allreduce failed, unable to find a group")
-
-        except Exception as e:
-            future.set_exception(e)
-            raise
-        finally:
-            self._pending_group_assembled.set()
-            if group_id is not None:
+                averaging_deltas = await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout)
+                update_ok = await loop.run_in_executor(None, lambda: self.update_tensors(averaging_deltas, add=True))
+
+                # averaging is finished, exit the loop
+                future.set_result(update_ok)
+                try_averaging = False
+
+            except AllreduceException:
+                time_elapsed = get_dht_time() - start_time
+                if not allow_retries or (timeout is not None and timeout < time_elapsed):
+                    future.set_result(False)
+                    try_averaging = False
+
+            except Exception as e:
+                future.set_exception(e)
+                raise
+            finally:
                 _ = self._running_groups.pop(group_id, None)
+                self._pending_group_assembled.set()
+
+    def update_tensors(self, tensors: Sequence[torch.Tensor], *, add: bool = False) -> bool:
+        """
+        Set or change the values of self.averaged_tensors.
+
+        :param tensors: list/tuple of tensors of same shape as self.averaged_tensors
+        :param add: if True, add tensors to self.averaged_tensors in-place
+          by default, simply write the values of :tensors: to self.averaged_tensors
+        :note: if there may be updates running in background, it is recommended to use add=True
+        """
+        assert len(tensors) == len(self._averaged_tensors)
+        with torch.no_grad(), self.lock_averaged_tensors:
+            for tensor, update in zip(self._averaged_tensors, tensors):
+                if add:
+                    tensor += update
+                else:
+                    tensor[...] = update
+        return True
+
+    @contextlib.contextmanager
+    def get_tensors(self) -> Sequence[torch.Tensor]:
+        """
+        A contextmanager that gives user access to averaged tensors.
+        It is guaranteed that the averager will not modify tensors while this context is active.
+
+        Example:
+              >>> with averager.get_tensors() as tensors:
+              >>>     update_model(tensors)
+              >>>     tensors[0] += 1
+              >>> # do not use tensors after the lock is released
+        """
+        with self.lock_averaged_tensors:
+            yield self._averaged_tensors
 
     async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                              ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -199,16 +252,22 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         async for response in self._matchmaking.rpc_join_group(request, context):
             yield response
 
-    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
+    async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+                                 ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
-        if request.group_id not in self._running_groups and not self._pending_group_assembled.is_set():
+        request = await anext(stream)
+        if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
             # but his response with group_id was delayed and other peers got to us first
             await self._pending_group_assembled.wait()
-        if request.group_id not in self._running_groups:
-            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
-        else:
-            return await self._running_groups[request.group_id].rpc_aggregate_part(request, context)
+
+        group = self._running_groups.get(request.group_id)
+        if group is None:
+            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+            return
+
+        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
+            yield message
 
 
 def is_power_of_two(n):
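
Taken together, the changes above give DecentralizedAverager a lock-protected usage pattern: write parameters into the shared buffers, call step(), then read the averaged values back. A usage sketch, assuming the averager was constructed with one tensor per model parameter ('model' and the surrounding training loop are hypothetical placeholders):

# write current parameters into the shared buffers; the lock keeps the averager out
with averager.get_tensors() as averaged_tensors:
    for param, averaged in zip(model.parameters(), averaged_tensors):
        averaged[...] = param.detach().cpu()

# run one round of decentralized averaging; step() now returns True/False instead of tensors
success = averager.step(allow_retries=True, timeout=60)

if success:
    # read the averaged values back; the averager will not modify them while the lock is held
    with averager.get_tensors() as averaged_tensors:
        for param, averaged in zip(model.parameters(), averaged_tensors):
            param.data.copy_(averaged)
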

+ 89 - 40
hivemind/client/averaging/allreduce.py

@@ -1,10 +1,11 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple
+from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Iterator
 
 import grpc
 import torch
 
-from hivemind.utils import Endpoint, get_logger, serialize_torch_tensor, deserialize_torch_tensor, ChannelCache
+from hivemind.utils import Endpoint, get_logger, ChannelCache, anext
+from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
 from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
 
 # flavour types
@@ -19,25 +20,32 @@ class AllReduceProtocol:
     :param tensors: local tensors that should be averaged with groupmates
     :param endpoint: your endpoint, must be included in ordered_group_endpoints
     :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    :param return_deltas: if True, returns the element-wise differences (averaged_tensors - original_tensors)
+           default (False) - return averaged_tensors by themselves
     """
+
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint]):
+                 ordered_group_endpoints: Sequence[Endpoint], return_deltas: bool = False):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
         self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, self.group_size)))
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
+        self.return_deltas = return_deltas
 
         self.accumulator = self.local_tensor_parts[self.endpoint].clone()  # sum inputs from peers to this tensor
         self.accumulated_from: Set[Endpoint] = {self.endpoint}  # peers that we have accumulated our part from
         self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
         self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
-        self.averaged_tensors: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
+        self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
 
     def __await__(self):
-        return self.averaged_tensors.__await__()
+        return self.future.__await__()
+
+    def __contains__(self, endpoint: Endpoint):
+        return endpoint in self.local_tensor_parts
 
     @property
     def group_size(self):
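
With return_deltas=True, run() yields the element-wise difference (average - local) rather than the averaged tensors themselves, so the caller can fold the update into live parameters without locking them for the whole allreduce; this is what DecentralizedAverager.update_tensors(..., add=True) consumes. A small worked example of the arithmetic (values are made up):

import torch

local = torch.tensor([1.0, 3.0])          # this peer's tensor before averaging
group_average = torch.tensor([2.0, 2.0])  # hypothetical result of the allreduce
delta = group_average - local             # what run() returns when return_deltas=True

local += delta                            # equivalent of update_tensors(deltas, add=True)
assert torch.allclose(local, group_average)
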
@@ -46,7 +54,7 @@ class AllReduceProtocol:
     async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor) -> torch.Tensor:
         """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
         assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
-        assert not self.averaged_tensors.done(), f"already finished allreduce: {self.averaged_tensors}"
+        assert not self.future.done(), f"already finished allreduce: {self.future}"
         assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
         assert source not in self.accumulated_from, "duplicate source, already received that part"
         logger.debug(f"{self} - accumulating tensor part from {source}")
@@ -63,7 +71,7 @@ class AllReduceProtocol:
         return await self.averaged_part
 
     def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
-        assert not self.averaged_tensors.done(), f"already finished allreduce: {self.averaged_tensors}"
+        assert not self.future.done(), f"already finished allreduce: {self.future}"
         assert source in self.local_tensor_parts, "the provider of averaged part is not from my group"
         assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
         assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
@@ -72,28 +80,37 @@ class AllReduceProtocol:
         self.averaged_tensor_parts[source] = averaged_part
         if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
             ordered_averaged_parts = [self.averaged_tensor_parts[endpoint] for endpoint in self.ordered_group_endpoints]
-            self.averaged_tensors.set_result(restore_from_parts(ordered_averaged_parts, self.tensor_shapes))
+            outputs = restore_from_parts(ordered_averaged_parts, self.tensor_shapes)
+
+            if self.return_deltas:
+                local_parts = [self.local_tensor_parts[peer] for peer in self.ordered_group_endpoints]
+                with torch.no_grad():
+                    original_tensors = restore_from_parts(local_parts, self.tensor_shapes)
+                    for averaged_tensor, original_tensor in zip(outputs, original_tensors):
+                        averaged_tensor -= original_tensor
+
+            self.future.set_result(outputs)
 
     def cancel(self) -> bool:
-        if not self.averaged_tensors.done():
+        if not self.future.done():
             logger.debug(f"{self} - cancelled")
-            self.averaged_tensors.cancel()
+            self.future.cancel()
             if not self.averaged_part.done():
                 self.averaged_part.cancel()
             return True
         else:
-            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.averaged_tensors}")
+            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.future}")
             return False
 
     def set_exception(self, exception: Exception) -> bool:
-        if not self.averaged_tensors.done():
+        if not self.future.done():
             logger.debug(f"{self} - {exception}")
-            self.averaged_tensors.set_exception(exception)
+            self.future.set_exception(exception)
             if not self.averaged_part.done():
                 self.averaged_part.cancel()
             return True
         else:
-            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.averaged_tensors}")
+            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.future}")
             return False
 
 
@@ -101,11 +118,14 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
     """
     A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
     """
+
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType):
+                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
+                 chunk_size_bytes: int, return_deltas: bool = False):
         super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint,
-                         ordered_group_endpoints=ordered_group_endpoints)
-        self.compression_type = compression_type
+                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
+        self.compression_type, self.chunk_size_bytes = compression_type, chunk_size_bytes
+        self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future()
 
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
@@ -113,55 +133,84 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
     async def _average_one_part(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
         """ Send one part of local tensors to one groupmate and collect the average for this part """
         serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-        response = await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(
-            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
-                                        endpoint=self.endpoint, tensor_part=serialized_tensor_part))
-        if response.code == averaging_pb2.AVERAGED_PART:
-            averaged_part = deserialize_torch_tensor(response.tensor_part)
-            self.register_averaged_part(peer_endpoint, averaged_part)
-            return averaged_part
-        else:
-            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(response.code)}"
+        chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)
+
+        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
+                                                       endpoint=self.endpoint, tensor_part=next(chunks)))
+        for chunk in chunks:
+            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
+        await stream.done_writing()
+
+        outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
+        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
+        if code != averaging_pb2.AVERAGED_PART:
+            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                                      f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                      f" allreduce failed")
 
+        averaged_part = deserialize_torch_tensor(combine_from_streaming([message.tensor_part for message in outputs]))
+        self.register_averaged_part(peer_endpoint, averaged_part)
+        return averaged_part
+
     async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        await self._get_peer_stub(peer_endpoint).rpc_aggregate_part(averaging_pb2.AveragingData(
-            group_id=self.group_id, endpoint=self.endpoint, code=code))
+        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
+        await stream.done_writing()
 
     async def run(self) -> Sequence[torch.Tensor]:
-        """ send allreduce requests to all peers and collect results, return the averaged tensor """
+        """
+        send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
+        """
         try:
             await asyncio.gather(self, *(self._average_one_part(peer, part)
                                          for peer, part in self.local_tensor_parts.items() if peer != self.endpoint))
             return await self
-        except Exception as e:
+        except BaseException as e:
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
             self.set_exception(e)
             for peer_endpoint in self.ordered_group_endpoints:
-                asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
+                if peer_endpoint != self.endpoint:
+                    asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
             raise
 
-    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
+    async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
+                                        ) -> Iterable[runtime_pb2.Tensor]:
+        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
+        tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
+        averaged_part = await self.accumulate_part(source, tensor_part)
+        if not self.averaged_part_stream.done():
+            serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
+            stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
+            self.averaged_part_stream.set_result(stream_chunks)
+            return stream_chunks
+        else:
+            return self.averaged_part_stream.result()
+
+    async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+                                 ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
+        request: averaging_pb2.AveragingData = await anext(stream)
+
         if request.group_id != self.group_id:
-            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
 
-        if request.code == averaging_pb2.PART_FOR_AVERAGING:
+        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                tensor_part = deserialize_torch_tensor(request.tensor_part)
-                averaged_part = await self.accumulate_part(request.endpoint, tensor_part)
-                serialized = serialize_torch_tensor(averaged_part, request.tensor_part.compression, allow_inplace=False)
-                return averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized)
+                tensor_chunks = (request.tensor_part, *[msg.tensor_part async for msg in stream])
+                averaged_chunks = iter(await self.accumulate_part_streaming(request.endpoint, tensor_chunks))
+                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=next(averaged_chunks))
+                for averaged_chunk in averaged_chunks:
+                    yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
             except Exception as e:
                 self.set_exception(e)
-                return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
             logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
             self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
-            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
 
 def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor, ...]:
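
The streaming changes above serialize each tensor part once, split it into chunk_size_bytes pieces, and send the pieces over a bidirectional gRPC stream; the receiver reassembles them with combine_from_streaming. A round-trip sketch of the helpers as they are used in this diff (their exact signatures are assumed from the imports above):

import torch
from hivemind.proto import runtime_pb2
from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils import split_for_streaming, combine_from_streaming

tensor = torch.randn(1024)
serialized = serialize_torch_tensor(tensor, runtime_pb2.CompressionType.NONE, allow_inplace=False)
chunks = list(split_for_streaming(serialized, 2 ** 16))  # pieces of at most ~64 KiB of payload each
restored = deserialize_torch_tensor(combine_from_streaming(chunks))
assert torch.allclose(tensor, restored)
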

+ 196 - 138
hivemind/client/averaging/matchmaking.py

@@ -6,7 +6,7 @@ import contextlib
 import random
 from dataclasses import asdict
 from math import isfinite
-from typing import Sequence, Optional, AsyncIterator, Set
+from typing import Sequence, Optional, AsyncIterator, Set, Tuple
 import asyncio
 
 import torch
@@ -27,29 +27,41 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
     f"""
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
+    
+    :note: on implementation: the current matchmaker protocol can encounter one type of (temporary) deadlock;
+      This deadlock occurs when averager A requests averager B at the same time as averager B requests averager A.
+      In that case, neither averager can process the other one's request because it is awaiting lock_request_join_group.
+      This deadlock only happens if averagers have outdated information on expirations (due to network delays). 
+      While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
+      Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
+    
     """
 
     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
                  prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, compression_type: runtime_pb2.CompressionType = runtime_pb2.NONE):
+                 averaging_expiration: float = 15, request_timeout: float, **allreduce_kwargs):
         assert '.' not in prefix, "group prefix must be a string without ."
+        if request_timeout is None or request_timeout >= averaging_expiration:
+            logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
+                           " matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring.")
 
         super().__init__()
         self.dht, self.endpoint, self.averaged_tensors = dht, endpoint, tuple(averaged_tensors)
         self.prefix, self.group_bits = prefix, initial_group_bits
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
-        self.averaging_expiration, self.compression_type = averaging_expiration, compression_type
-
+        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.allreduce_kwargs = allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
 
         self.lock_looking_for_group = asyncio.Lock()
         self.lock_request_join_group = asyncio.Lock()
-        self.cond_notify_followers = asyncio.Condition()
+        self.follower_was_discarded = asyncio.Event()
+        self.was_accepted_to_group = asyncio.Event()
         self.assembled_group = asyncio.Future()
 
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Set[Endpoint] = set()  # iff i am a leader, this contains my followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.endpoint, self.dht, self.averaging_expiration)
+        self.potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration, target_group_size)
 
     @property
     def is_looking_for_group(self):
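
The deadlock described in the Matchmaking docstring above (A requests B while B requests A, each holding its own lock_request_join_group) is not detected explicitly; instead, every leader request is bounded by request_timeout so the cycle breaks on its own. A toy asyncio sketch of this pattern, unrelated to the actual hivemind classes:

import asyncio

async def peer(own_lock: asyncio.Lock, other_lock: asyncio.Lock, request_timeout: float) -> bool:
    async with own_lock:  # analogous to holding lock_request_join_group
        try:
            # requesting the other peer blocks while it holds its own lock...
            await asyncio.wait_for(other_lock.acquire(), timeout=request_timeout)
            other_lock.release()
            return True
        except asyncio.TimeoutError:
            return False  # ...so the timeout breaks the cycle and both peers back off

async def main():
    lock_a, lock_b = asyncio.Lock(), asyncio.Lock()
    print(await asyncio.gather(peer(lock_a, lock_b, 3.0), peer(lock_b, lock_a, 3.0)))

asyncio.run(main())
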
@@ -70,7 +82,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
                f" current key = {self.current_group_key})"
 
-    async def look_for_group(self, *, timeout: Optional[float] = None) -> AllReduceRunner:
+    async def look_for_group(self, *, timeout: Optional[float] = None) -> Optional[AllReduceRunner]:
         """
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
@@ -82,48 +94,58 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
             try:
                 return await asyncio.wait_for(self.assembled_group, timeout=timeout)
-            except Exception as e:
+            except asyncio.TimeoutError:
+                return None
+
+            except BaseException as e:
                 if len(self.current_followers) > 0:
                     async with self.lock_request_join_group:
                         await self.leader_disband_group()
-                self.assembled_group.set_exception(e)
+                if not self.assembled_group.done():
+                    self.assembled_group.set_exception(e)
                 raise
 
             finally:
                 if not request_leaders_task.done():
                     request_leaders_task.cancel()
-                if self.assembled_group.done():
-                    self.assembled_group = asyncio.Future()
+                if not self.assembled_group.done():
+                    self.assembled_group.cancel()
+                while len(self.current_followers) > 0:
+                    await self.follower_was_discarded.wait()
+                    self.follower_was_discarded.clear()
+                # note: the code above ensures that we send all followers away before creating new future
+                self.assembled_group = asyncio.Future()
+                self.was_accepted_to_group.clear()
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
-        end_time = get_dht_time() + timeout if timeout is not None else float('inf')
         async with self.potential_leaders.begin_search(self.current_group_key, timeout):
             # TODO update group_bits on success! reduce number of bits on not enough peers.
             # TODO after allreduce finishes, we may need to ask leader to notify lower keys about this
             # (so as to fix possible network partitioning if some peers operate on a much smaller nbits)
             while True:
                 try:
-                    time_to_expiration = self.potential_leaders.declared_expiration_time - get_dht_time()
-                    next_best_leader = await asyncio.wait_for(
-                        self.potential_leaders.pop_next_leader(),
-                        timeout=time_to_expiration if isfinite(time_to_expiration) else None)
-
-                    request_expiration_time = min(self.potential_leaders.declared_expiration_time,
-                                                  end_time, get_dht_time() + self.averaging_expiration)
-                    group = await self.request_join_group(next_best_leader, request_expiration_time)
+                    next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
+
+                    group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
                     if group is not None:
                         return group
 
                 except asyncio.TimeoutError:
                     async with self.lock_request_join_group:
-                        if len(self.current_followers) >= self.min_group_size:
+                        if self.assembled_group.done():
+                            return self.assembled_group.result()
+                        elif len(self.current_followers) + 1 >= self.min_group_size:
                             # the time is up, we have a *good enough* group. run allreduce as is.
                             return await self.leader_assemble_group()
-                        else:
+                        elif len(self.current_followers) > 0:
                             await self.leader_disband_group()
                             # TODO maybe adjust grid size
-                            continue
+                        continue
+                except Exception as e:
+                    if not self.assembled_group.done():
+                        self.assembled_group.set_exception(e)
+                    raise e
 
     async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[AllReduceRunner]:
         """
@@ -134,87 +156,101 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        call: Optional[grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]] = None
+        call: Optional[grpc.aio.UnaryStreamCall] = None
         try:
             async with self.lock_request_join_group:
                 leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                 call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                     endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time))
+                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
 
-                message = await call.read()
-                if message.code != averaging_pb2.ACCEPTED:
-                    code = averaging_pb2.MessageCode.Name(message.code)
-                    logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
-                    return None
+                if message.code == averaging_pb2.ACCEPTED:
+                    logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                    self.current_leader = leader
+                    self.was_accepted_to_group.set()
+                    if len(self.current_followers) > 0:
+                        await self.leader_disband_group()
 
-                # else: we were accepted
-                logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
-                self.current_leader = leader
-                if len(self.current_followers) > 0:
-                    await self.leader_disband_group()
+            if message.code != averaging_pb2.ACCEPTED:
+                code = averaging_pb2.MessageCode.Name(message.code)
+                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                return None
 
             async with self.potential_leaders.pause_search():
-                message = await call.read()
+                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
+                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)
 
-            if message.code == averaging_pb2.BEGIN_ALLREDUCE:
-                async with self.lock_request_join_group:
-                    return await self.follower_assemble_group(leader, message.group_id, message.ordered_group_endpoints)
-            elif message.code == averaging_pb2.GROUP_DISBANDED and bool(message.suggested_leader):
-                logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
-                return await self.request_join_group(message.suggested_leader, expiration_time)
+                if message.code == averaging_pb2.BEGIN_ALLREDUCE:
+                    async with self.lock_request_join_group:
+                        return await self.follower_assemble_group(
+                            leader, message.group_id, message.ordered_group_endpoints)
+
+            if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
+                if message.suggested_leader and message.suggested_leader != self.endpoint:
+                    logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
+                    self.current_leader = None
+                    call.cancel()
+                    return await self.request_join_group(message.suggested_leader, expiration_time)
+                else:
+                    logger.debug(f"{self} - leader disbanded group")
+                    return None
 
-            else:
-                logger.debug(f"{self} - leader sent {averaging_pb2.MessageCode.Name(message.code)}, leaving group")
-                return None
+            logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
+            return None
+        except asyncio.TimeoutError:
+            logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
+            if call is not None:
+                call.cancel()
+            return None
         finally:
+            self.was_accepted_to_group.clear()
             self.current_leader = None
             if call is not None:
-                call.cancel()
+                await call.code()
 
     async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                              ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
         try:
-            reason_to_reject = self._check_reasons_to_reject(request)
-            if reason_to_reject is not None:
-                yield reason_to_reject
-                return
-
-            current_group = self.assembled_group  # copy current assembled_group to avoid overwriting
             async with self.lock_request_join_group:
+                reason_to_reject = self._check_reasons_to_reject(request)
+                if reason_to_reject is not None:
+                    yield reason_to_reject
+                    return
+
                 self.current_followers.add(request.endpoint)
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
-                if len(self.current_followers) + 1 >= self.target_group_size:
+                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
                     # outcome 1: we have assembled a full group and are ready for allreduce
                     await self.leader_assemble_group()
 
-            if not current_group.done():
-                try:
-                    async with self.cond_notify_followers:
-                        # wait for the group to be assembled or disbanded
-                        timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
-                        await asyncio.wait_for(self.cond_notify_followers.wait(), timeout=timeout)
-                except asyncio.TimeoutError:
-                    async with self.lock_request_join_group:
+            # wait for the group to be assembled or disbanded
+            timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
+            await asyncio.wait({self.assembled_group, self.was_accepted_to_group.wait()},
+                               return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
+            if not self.assembled_group.done() and not self.was_accepted_to_group.is_set():
+                async with self.lock_request_join_group:
+                    if self.assembled_group.done():
+                        pass  # this covers a rare case when the group is assembled while the event loop was busy.
+                    elif len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
                         # outcome 2: the time is up, run allreduce with what we have or disband
-                        if len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
-                            await self.leader_assemble_group()
-                        else:
-                            await self.leader_disband_group()
-
-            if self.current_leader is not None:
-                # outcome 3: found by a leader with higher priority, send our followers to him
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
-                                                      suggested_leader=self.current_leader)
-                return
+                        await self.leader_assemble_group()
+                    else:
+                        await self.leader_disband_group()
 
-            if request.endpoint not in self.current_followers:
-                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
-                return
+            if self.was_accepted_to_group.is_set() or not self.assembled_group.done() \
+                    or self.assembled_group.cancelled() or request.endpoint not in self.assembled_group.result():
+                if self.current_leader is not None:
+                    # outcome 3: found by a leader with higher priority, send our followers to him
+                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
+                                                          suggested_leader=self.current_leader)
+                    return
+                else:
+                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
+                    return
 
-            # finally, run allreduce
-            allreduce_group = current_group.result()
+            allreduce_group = self.assembled_group.result()
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
                 ordered_group_endpoints=allreduce_group.ordered_group_endpoints)
@@ -225,10 +261,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
             self.current_followers.discard(request.endpoint)
+            self.follower_was_discarded.set()
 
-    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> averaging_pb2.MessageFromLeader:
+    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
         """ :returns: if accepted, return None, otherwise return a reason for rejection """
-        if not self.is_looking_for_group:
+        if not self.is_looking_for_group or self.assembled_group.done():
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)
 
         if request.ListFields() == 3 and not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
@@ -243,8 +280,8 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float('inf')):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
         elif self.current_leader is not None:
-            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER,
-                                                   suggested_leader=self.current_leader)
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
+                                                   )  # note: this suggested leader is currently ignored
         elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
         elif len(self.current_followers) + 1 >= self.target_group_size:
@@ -255,68 +292,71 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
     async def leader_assemble_group(self) -> AllReduceRunner:
         """ Form up all current followers into a group and prepare to _run_allreduce """
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        assert not self.assembled_group.done()
         group_id = DHTID.generate().to_bytes()
         ordered_group_endpoints = list(self.current_followers)
         ordered_group_endpoints.append(self.endpoint)
         random.shuffle(ordered_group_endpoints)
-        logger.debug(f"{self.endpoint} - leader started allreduce with {len(ordered_group_endpoints)} followers.")
-        allreduce_group = AllReduceRunner(
-            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
+        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
-        async with self.cond_notify_followers:
-            self.cond_notify_followers.notify_all()
         return allreduce_group
 
     async def follower_assemble_group(self, leader: Endpoint, group_id: GroupID,
                                       ordered_group_endpoints: Sequence[Endpoint]) -> AllReduceRunner:
         """ Prepare to run allreduce using a list of peers provided by our leader """
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        assert not self.assembled_group.done()
         logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
-        allreduce_group = AllReduceRunner(
-            group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-            ordered_group_endpoints=ordered_group_endpoints, compression_type=self.compression_type)
+        allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
+                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
-        async with self.cond_notify_followers:
-            self.cond_notify_followers.notify_all()
         return allreduce_group
 
     async def leader_disband_group(self):
         """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
         assert self.lock_request_join_group.locked()
         self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out
-        async with self.cond_notify_followers:
-            self.cond_notify_followers.notify_all()
 
 
 class PotentialLeaders:
     """ An utility class that searches for averagers that could become our leaders """
-    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration):
+
+    def __init__(self, endpoint: Endpoint, dht: hivemind.DHT, averaging_expiration: DHTExpiration,
+                 target_group_size: Optional[int]):
         self.endpoint, self.dht, self.averaging_expiration = endpoint, dht, averaging_expiration
+        self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
+        self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
         self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
-        self.max_assured_time = float('-inf')
+        self.past_attempts: Set[Tuple[Endpoint, DHTExpiration]] = set()
         self.declared_expiration_time = float('inf')
         self.declared_group_key: Optional[GroupKey] = None
+        self.max_assured_time = float('-inf')
         self.search_end_time = float('inf')
 
     @contextlib.asynccontextmanager
     async def begin_search(self, group_key: GroupKey, timeout: Optional[float]):
-        assert not self.running.is_set(), "already running"
-        self.running.set()
-        self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
-        update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
-        declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
-        try:
-            yield self
-        finally:
-            update_queue_task.cancel()
-            declare_averager_task.cancel()
-            self.running.clear()
-            self.update_triggered.clear()
-            self.update_finished.clear()
+        async with self.lock_search:
+            self.running.set()
+            self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
+            update_queue_task = asyncio.create_task(self._update_queue_periodically(group_key))
+            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(group_key))
+            try:
+                yield self
+            finally:
+                if not update_queue_task.done():
+                    update_queue_task.cancel()
+                if not declare_averager_task.done():
+                    declare_averager_task.cancel()
+                for field in (self.past_attempts, self.leader_queue, self.running,
+                              self.update_finished, self.update_triggered, self.declared_expiration):
+                    field.clear()
+                self.max_assured_time = float('-inf')
+                self.search_end_time = float('inf')
 
     @contextlib.asynccontextmanager
     async def pause_search(self):
@@ -332,19 +372,34 @@ class PotentialLeaders:
 
     async def pop_next_leader(self) -> Endpoint:
         """ Remove and return the next most suitable leader or throw an exception if reached timeout """
-        assert self.running, "Not running search at the moment"
-        maybe_next_leader, entry = self.leader_queue.top()
-
-        next_entry_time = entry.expiration_time if maybe_next_leader is not None else get_dht_time()
-        if self.max_assured_time < next_entry_time < self.search_end_time:
-            self.update_triggered.set()
+        assert self.running.is_set(), "Not running search at the moment"
+        while True:
+            maybe_next_leader, entry = self.leader_queue.top()
+
+            if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
+                self.update_triggered.set()
+
+            if maybe_next_leader is None or entry.expiration_time >= self.declared_expiration_time:
+                await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
+                                   return_when=asyncio.FIRST_COMPLETED)
+                self.declared_expiration.clear()
+                if self.update_finished.is_set():
+                    self.update_finished.clear()
+                    continue
+                else:
+                    raise asyncio.TimeoutError("pop_next_leader was invalidated: re-declared averager in background")
 
-        if maybe_next_leader is None:
-            await self.update_finished.wait()
-            return await self.pop_next_leader()
+            del self.leader_queue[maybe_next_leader]
+            self.past_attempts.add((maybe_next_leader, entry.expiration_time))
+            return maybe_next_leader
 
-        del self.leader_queue[maybe_next_leader]
-        return maybe_next_leader
+    @property
+    def request_expiration_time(self) -> float:
+        """ this averager's current expiration time - used to send join requests to leaders """
+        if isfinite(self.declared_expiration_time):
+            return self.declared_expiration_time
+        else:
+            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
 
     async def _update_queue_periodically(self, group_key: GroupKey):
         DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
@@ -352,14 +407,14 @@ class PotentialLeaders:
             new_peers = await self.dht.get_averagers(group_key, only_active=True, return_future=True)
             self.max_assured_time = max(self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY)
 
+            self.leader_queue.clear()
             for peer, peer_expiration_time in new_peers:
-                if peer == self.endpoint:
+                if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
                     continue
                 self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                 self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
 
-            if len(self.leader_queue) > 0:
-                self.update_finished.set()
+            self.update_finished.set()
 
             await asyncio.wait(
                 {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
@@ -367,28 +422,31 @@ class PotentialLeaders:
             self.update_triggered.clear()
 
     async def _declare_averager_periodically(self, group_key: GroupKey):
-        try:
-            while True:
-                new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-                self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
-                stored_ok = await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
-                                                            looking_for_group=True, return_future=True)
-                if stored_ok:
+        async with self.lock_declare:
+            try:
+                while True:
+                    await self.running.wait()
+
+                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                    self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
+                    self.declared_expiration.set()
+                    await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
+                                                    looking_for_group=True, return_future=True)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
-                else:
-                    logger.warning(f"Failed to subscribe to group {group_key} : store rejected by DHT peers")
-        finally:
-            if self.declared_group_key is not None:
-                previous_declared_key, previous_expiration_time = self.declared_group_key, self.declared_expiration_time
-                self.declared_group_key, self.declared_expiration_time = None, float('inf')
-                self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
-                await self.dht.declare_averager(previous_declared_key, self.endpoint, previous_expiration_time,
-                                                looking_for_group=False, return_future=True)
+            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
+                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            finally:
+                if self.declared_group_key is not None:
+                    prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
+                    self.declared_group_key, self.declared_expiration_time = None, float('inf')
+                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
+                    await self.dht.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
+                                                    looking_for_group=False, return_future=True)
 
 
 def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
     """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
     schema_dicts = [{field_name: str(field_value)
-                    for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
+                     for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                     for tensor in tensors]
     return DHTID.generate(source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
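
For orientation, here is a minimal sketch (not part of the diff) of how the reworked PotentialLeaders API above is meant to be driven from the matchmaking loop; the group key, timeout values, and the actual join-request step are placeholders:

    async def find_a_leader(endpoint, dht):
        potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration=15, target_group_size=4)
        async with potential_leaders.begin_search("some_group_key", timeout=30):
            while True:
                try:
                    leader = await potential_leaders.pop_next_leader()  # best remaining candidate
                except asyncio.TimeoutError:
                    return None  # we were re-declared in the background: stop following, lead our own group
                # a real matchmaker would now send a join request using
                # potential_leaders.request_expiration_time and return on success
                ...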

+ 7 - 9
hivemind/dht/__init__.py

@@ -27,7 +27,7 @@ from numpy import nextafter
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time, DHTValue
-from hivemind.utils import MPFuture, Endpoint, get_logger
+from hivemind.utils import MPFuture, Endpoint, get_logger, switch_to_uvloop
 
 logger = get_logger(__name__)
 
@@ -141,11 +141,7 @@ class DHT(mp.Process):
 
     def run(self) -> None:
         """ Serve DHT forever. This function will not return until DHT node is shut down """
-        if asyncio.get_event_loop().is_running():
-            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-        uvloop.install()
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+        loop = switch_to_uvloop()
         pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
 
         async def _run():
@@ -497,13 +493,14 @@ class DHT(mp.Process):
     async def _declare_averager(self, node: DHTNode, *, group_key: str, endpoint: Endpoint,
                                 expiration_time: DHTExpiration, looking_for_group: bool, future: MPFuture):
         try:
-            expiration_time = expiration_time if looking_for_group else nextafter(expiration_time, float('inf'))
+            expiration_time = expiration_time if looking_for_group else float(nextafter(expiration_time, float('inf')))
             # ^-- when declaring averager inactive, we increment expiration time to overwrite the pre-existing entry
             store_ok = await node.store(
                 key=group_key, subkey=endpoint, value=looking_for_group, expiration_time=expiration_time)
             future.set_result(store_ok)
         except Exception as e:
-            future.set_exception(e)
+            if not future.done():
+                future.set_exception(e)
 
     def get_averagers(self, group_key: GroupKey, *, only_active: bool = True, return_future: bool = False
                       ) -> Union[List[Tuple[Endpoint, DHTExpiration]], MPFuture]:
@@ -534,4 +531,5 @@ class DHT(mp.Process):
                          if not only_active or entry.value is True]
             future.set_result(averagers)
         except Exception as e:
-            future.set_exception(e)
+            if not future.done():
+                future.set_exception(e)
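
A hedged usage sketch of the declare/get pair that matchmaking relies on (assuming the blocking return_future=False path; the group key and endpoint are placeholders, not values from this diff):

    import hivemind
    from hivemind.utils import get_dht_time

    dht = hivemind.DHT(start=True)
    group_key, endpoint = "some_group_key", "localhost:1337"
    stored_ok = dht.declare_averager(group_key, endpoint, get_dht_time() + 15, looking_for_group=True)
    assert stored_ok
    averagers = dht.get_averagers(group_key, only_active=True)  # -> [(endpoint, expiration_time), ...]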

+ 15 - 5
hivemind/dht/node.py

@@ -357,7 +357,7 @@ class DHTNode:
         """
         if latest:
             kwargs["sufficient_expiration_time"] = float('inf')
-        result = await self.get_many([key])
+        result = await self.get_many([key], **kwargs)
         return result[key]
 
     async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
@@ -579,10 +579,20 @@ class _SearchState:
     future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
     serializer: type(SerializerBase) = MSGPackSerializer
 
-    def add_candidate(self, candidate: Optional[ValueWithExpiration[BinaryDHTValue]], source_node_id: Optional[DHTID]):
-        binary_value, expiration_time = candidate or (None, -float('inf'))
-        if not self.finished and expiration_time > (self.expiration_time or -float('inf')):
-            self.binary_value, self.expiration_time, self.source_node_id = binary_value, expiration_time, source_node_id
+    def add_candidate(self, candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
+                      source_node_id: Optional[DHTID]):
+        if self.finished or candidate is None:
+            return
+        elif isinstance(candidate.value, DictionaryDHTValue) and isinstance(self.binary_value, DictionaryDHTValue):
+            self.binary_value.maxsize = max(self.binary_value.maxsize, candidate.value.maxsize)
+            for subkey, subentry in candidate.value.items():
+                self.binary_value.store(subkey, subentry.value, subentry.expiration_time)
+        elif candidate.expiration_time > (self.expiration_time or float('-inf')):
+            self.binary_value = candidate.value
+
+        if candidate.expiration_time > (self.expiration_time or float('-inf')):
+            self.expiration_time = candidate.expiration_time
+            self.source_node_id = source_node_id
             if self.expiration_time >= self.sufficient_expiration_time:
                 self.finish_search()
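
A small illustration (names are placeholders) of what the one-line DHTNode.get fix above enables: keyword arguments such as latest=True are now actually forwarded to get_many, so a freshest-value lookup is no longer silently ignored:

    found = await node.get("some_key", latest=True)  # previously, latest=True was dropped on the floor
    if found is not None:
        value, expiration_time = found               # ValueWithExpiration unpacks as a 2-tuple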
 

+ 1 - 1
hivemind/dht/protocol.py

@@ -44,7 +44,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """
         self = cls(_initialized_with_create=True)
         self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
-        self.wait_timeout, self.channel_options = wait_timeout, channel_options
+        self.wait_timeout, self.channel_options = wait_timeout, tuple(channel_options)
         self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))

+ 1 - 1
hivemind/proto/averaging.proto

@@ -5,7 +5,7 @@ import "runtime.proto";
 // Runs alongside each trainer to perform gating function averaging every now and then. Read more: client/averaging.py
 service DecentralizedAveraging {
   rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
-  rpc rpc_aggregate_part(AveragingData) returns (AveragingData);  // send my local shard => get aggregated shard
+  rpc rpc_aggregate_part(stream AveragingData) returns (stream AveragingData);  // send local part => get average part
 }
 
 enum MessageCode {

+ 1 - 0
hivemind/proto/runtime.proto

@@ -38,5 +38,6 @@ message Tensor {
   bool requires_grad = 3;
   string dtype = 4;
   CompressionType compression = 5;
+  int32 chunks = 6;
 }
 

+ 1 - 0
hivemind/utils/__init__.py

@@ -7,3 +7,4 @@ from hivemind.utils.threading import *
 from hivemind.utils.grpc import *
 from hivemind.utils.timed_storage import *
 from hivemind.utils.logging import get_logger
+from hivemind.utils.asyncio import *

+ 34 - 0
hivemind/utils/asyncio.py

@@ -0,0 +1,34 @@
+from typing import TypeVar, AsyncIterator, Union, AsyncIterable
+import asyncio
+import uvloop
+T = TypeVar('T')
+
+
+def switch_to_uvloop() -> asyncio.AbstractEventLoop:
+    """ stop any running event loops; install uvloop; then create, set and return a new event loop """
+    try:
+        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+    except RuntimeError as error_no_event_loop:
+        pass  # this allows running DHT from background threads with no event loop
+    uvloop.install()
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    return loop
+
+
+async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
+    """ equivalent to next(iter) for asynchronous iterators. Modifies aiter in-place! """
+    return await aiter.__anext__()
+
+
+async def aiter(*args: T) -> AsyncIterator[T]:
+    """ create an asynchronous iterator from a sequence of values """
+    for arg in args:
+        yield arg
+
+
+async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
+    """ equivalent to chain(iter1, iter2, ...) for asynchronous iterators. """
+    for aiter in async_iters:
+        async for elem in aiter:
+            yield elem
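
A quick self-contained sketch (not part of the diff) of how these helpers compose:

    async def _demo():
        stream = achain(aiter(1, 2), aiter(3))            # chain two async iterators
        assert await anext(stream) == 1                   # pull one element, advancing the iterator
        assert [item async for item in stream] == [2, 3]  # consume the rest

    switch_to_uvloop().run_until_complete(_demo())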

+ 29 - 4
hivemind/utils/grpc.py

@@ -2,19 +2,20 @@
 Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
 """
 from __future__ import annotations
+
 import os
 import threading
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type
+from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable
 
 import grpc
 import numpy as np
 import torch
+from hivemind.proto.runtime_pb2 import CompressionType
 
 from hivemind.proto import runtime_pb2
-from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
-from hivemind.utils.networking import Endpoint
 from hivemind.utils.logging import get_logger
+from hivemind.utils.networking import Endpoint
+from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
 
 logger = get_logger(__name__)
 
@@ -235,3 +236,27 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
 
     tensor.requires_grad_(serialized_tensor.requires_grad)
     return tensor
+
+
+def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int) -> Iterator[runtime_pb2.Tensor]:
+    """ Split serialized_tensor into multiple chunks for gRPC streaming """
+    buffer = memoryview(serialized_tensor.buffer)
+    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
+    yield runtime_pb2.Tensor(
+        compression=serialized_tensor.compression, buffer=buffer[:chunk_size_bytes].tobytes(), chunks=num_chunks,
+        size=serialized_tensor.size, dtype=serialized_tensor.dtype, requires_grad=serialized_tensor.requires_grad)
+    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
+        yield runtime_pb2.Tensor(buffer=buffer[chunk_start: chunk_start + chunk_size_bytes].tobytes())
+
+
+def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
+    """ Restore a result of split_into_chunks into a single serialized tensor """
+    stream = iter(stream)
+    first_chunk = next(stream)
+    serialized_tensor = runtime_pb2.Tensor()
+    serialized_tensor.CopyFrom(first_chunk)
+    buffer_chunks = [first_chunk.buffer]
+    for tensor_part in stream:
+        buffer_chunks.append(tensor_part.buffer)
+    serialized_tensor.buffer = b''.join(buffer_chunks)
+    return serialized_tensor
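
These helpers are exercised by tests/test_util_modules.py below; a condensed round-trip sketch (assuming the top-level hivemind re-exports used in that test):

    import torch
    import hivemind

    tensor = torch.randn(512, 12288)
    serialized = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.NONE)
    chunks = list(hivemind.split_for_streaming(serialized, chunk_size_bytes=64 * 1024))
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)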

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -54,7 +54,7 @@ class MPFuture(base.Future):
                 self.connection.close()
         except TimeoutError as e:
             raise e
-        except (BrokenPipeError, OSError) as e:
+        except (BrokenPipeError, OSError, EOFError) as e:
             if self._state in (base.PENDING, base.RUNNING):
                 self._state, self._exception = base.FINISHED, e
 

+ 4 - 0
hivemind/utils/tensor_descr.py

@@ -1,3 +1,4 @@
+import warnings
 from dataclasses import dataclass, asdict
 
 import torch
@@ -6,6 +7,9 @@ from hivemind.proto.runtime_pb2 import CompressionType
 
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
+warnings.filterwarnings("ignore", "CUDA initialization*", category=UserWarning)
+# ^-- cures https://github.com/pytorch/pytorch/issues/47038
+
 
 @dataclass(init=True, repr=True, frozen=True)
 class DescriptorBase:

+ 5 - 0
hivemind/utils/timed_storage.py

@@ -81,6 +81,11 @@ class TimedStorage(Generic[KeyType, ValueType]):
             return top_key, self.data[top_key]
         return None, None
 
+    def clear(self):
+        self.data.clear()
+        self.key_to_heap.clear()
+        self.expiration_heap.clear()
+
     def __contains__(self, key: KeyType):
         self._remove_outdated()
         return key in self.data

+ 88 - 0
tests/benchmark_averaging.py

@@ -0,0 +1,88 @@
+import time
+import threading
+import argparse
+
+import torch
+import hivemind
+from hivemind.utils import LOCALHOST, increase_file_limit
+from hivemind.proto import runtime_pb2
+
+
+def sample_tensors(hid_size, num_layers):
+    tensors = []
+    for i in range(num_layers):
+        tensors.append(torch.randn(hid_size, 3 * hid_size))
+        tensors.append(torch.randn(3 * hid_size))
+        tensors.append(torch.randn(3 * hid_size))
+        tensors.append(torch.randn(hid_size, hid_size))
+        tensors.append(torch.ones(hid_size))
+        tensors.append(torch.zeros(hid_size))
+        tensors.append(torch.randn(hid_size, 4 * hid_size))
+        tensors.append(torch.randn(4 * hid_size))
+        tensors.append(torch.ones(4 * hid_size))
+        tensors.append(torch.randn(2, hid_size, hid_size, 2))
+        tensors.append(torch.randn(hid_size))
+        tensors.append(torch.randn(hid_size))
+        tensors.append(torch.randn(hid_size))
+    return tuple(tensors)
+
+
+def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
+                        averaging_expiration: float, request_timeout: float, round_timeout: float,
+                        hid_size: int, num_layers: int, spawn_dtime: float):
+    dht_root = hivemind.DHT(listen_on=f'{LOCALHOST}:*', start=True)
+    peer_tensors = [sample_tensors(hid_size, num_layers)
+                    for _ in range(num_peers)]
+    processes = {dht_root}
+
+    def run_averager(index):
+        dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
+                           initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
+                           start=True)
+        averager = hivemind.DecentralizedAverager(
+            peer_tensors[index], dht, prefix='my_tensor', initial_group_bits='0110', listen_on=f"{LOCALHOST}:*",
+            compression_type=runtime_pb2.CompressionType.FLOAT16, target_group_size=target_group_size,
+            averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
+        processes.update({dht, averager})
+
+        print(end=f'<started {index}>\n', flush=True)
+        for _ in range(num_rounds):
+            success = averager.step(timeout=round_timeout)
+            print(end=('+' if success else '-'), flush=True)
+        print(end=f'<finished {index}>\n', flush=True)
+
+    threads = []
+    for i in range(num_peers):
+        thread = threading.Thread(target=run_averager, args=[i])
+        threads.append(thread)
+        thread.start()
+        time.sleep(spawn_dtime)
+
+    t = time.time()
+    for thread in threads:
+        thread.join()
+
+    print(f"\ntest run took {time.time() - t:.3f} seconds")
+
+    for process in processes:
+        process.terminate()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--num_peers', type=int, default=16, required=False)
+    parser.add_argument('--target_group_size', type=int, default=4, required=False)
+    parser.add_argument('--num_rounds', type=int, default=5, required=False)
+    parser.add_argument('--hid_size', type=int, default=256, required=False)
+    parser.add_argument('--num_layers', type=int, default=3, required=False)
+    parser.add_argument('--averaging_expiration', type=float, default=15, required=False)
+    parser.add_argument('--round_timeout', type=float, default=30, required=False)
+    parser.add_argument('--request_timeout', type=float, default=3, required=False)
+    parser.add_argument('--spawn_dtime', type=float, default=0.1, required=False)
+    parser.add_argument('--increase_file_limit', action="store_true")
+    args = vars(parser.parse_args())
+
+    if args.pop('increase_file_limit', False):
+        increase_file_limit()
+
+    benchmark_averaging(**args)
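
With the defaults above, this benchmark spawns 16 peers (one DHT plus one averager each, connected to a shared root DHT) that form groups of up to 4 and run 5 averaging rounds; it prints '+' for each successful round and '-' for a failed one. For a quicker smoke test one can pass smaller values, e.g. python tests/benchmark_averaging.py --num_peers 8 --num_rounds 3 (an illustrative invocation, not from the diff); --increase_file_limit raises the open-file limit if many peers exhaust file descriptors.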

+ 10 - 9
tests/test_averaging.py

@@ -34,8 +34,7 @@ def test_getset_averagers():
 
 
 @pytest.mark.forked
-@pytest.mark.asyncio
-async def test_allreduce_once():
+def test_allreduce_once():
     dht = hivemind.DHT(start=True)
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
@@ -52,12 +51,14 @@ async def test_allreduce_once():
 
     futures = []
     for averager in averagers:
-        futures.append(averager.step(return_future=True))  # TODO revert to hard version
-        time.sleep(0.5)
-
+        futures.append(averager.step(wait=False))
     for future in futures:
-        for ref, our in zip(reference, future.result()):
-            assert torch.allclose(ref, our)
+        assert future.result() is True
+
+    for averager in averagers:
+        with averager.get_tensors() as averaged_tensors:
+            for ref, our in zip(reference, averaged_tensors):
+                assert torch.allclose(ref, our, atol=1e-6)
 
 
 @pytest.mark.forked
@@ -90,7 +91,7 @@ async def test_allreduce_protocol():
     ]
 
     for peer, allreduce in zip(peers, allreduce_protocols):
-        assert allreduce.averaged_tensors.done()
+        assert allreduce.future.done()
         averaged_tensors = await allreduce
         assert len(averaged_tensors) == len(reference_tensors)
         assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
@@ -98,7 +99,7 @@ async def test_allreduce_protocol():
 
 
 @pytest.mark.forked
-def test_chunks():
+def test_partitioning():
     for _ in range(100):
         tensors = []
         for _ in range(random.randint(1, 5)):

+ 25 - 0
tests/test_util_modules.py

@@ -166,3 +166,28 @@ async def test_channel_cache():
         for j in range(i + 1, len(all_channels)):
             ci, cj = all_channels[i], all_channels[j]
             assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
+
+
+def test_serialize_tensor():
+    tensor = torch.randn(512, 12288)
+
+    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.NONE)
+    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
+        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+        restored = hivemind.combine_from_streaming(chunks)
+        assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
+
+    chunk_size = 30 * 1024
+    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.FLOAT16)
+    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    restored = hivemind.combine_from_streaming(chunks)
+    assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)
+
+    tensor = torch.randint(0, 100, (512, 1, 1))
+    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.NONE)
+    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    restored = hivemind.combine_from_streaming(chunks)
+    assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)