Add per-tensor compression, make all-reduce faster and more flexible (#272)

* extract core components of all-reduce as TensorPartContainer / TensorPartition
* ensure that compression happens asynchronously and with background threads (per chunk)
* update AllReduceRunner for new partitioning 
* per-tensor compression in TensorPartContainer (see the sketch after the test list below)
* minimize memory allocation (e.g. use iterator in update_tensors)
* update DecentralizedAverager to use new AllReduceProtocol

Tests:
* test that partitioning recovers the original tensors
* test partitioning edge cases (e.g. empty tensors)
* test that partitioning is indeed asynchronous
* test new all-reduce protocol in separate file
* test asyncio utility functions
* benchmark performance under limited bandwidth (see PR discussion)
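
To make the per-tensor compression change concrete, here is a minimal sketch that assumes only the TensorPartContainer interface added in hivemind/client/averaging/partition.py in this PR; the tensors, peer fractions and compression choices are illustrative, not prescribed by the change:

import asyncio

import torch

from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.client.averaging.partition import TensorPartContainer

async def sketch_per_tensor_compression():
    # two local tensors, each compressed with its own algorithm before streaming
    tensors = [torch.randn(1000), torch.randn(500)]
    container = TensorPartContainer(
        tensors, peer_fractions=(0.5, 0.5),
        compression_type=[CompressionType.FLOAT16, CompressionType.NONE],
        part_size_bytes=2 ** 20)

    # serialization/compression runs in a background thread, one part at a time
    async for serialized_part in container.iterate_input_parts_for(peer_index=0):
        pass  # in the real protocol, each serialized part is streamed to the corresponding peer

    container.finalize()

asyncio.run(sketch_per_tensor_compression())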


Co-authored-by: mponty <heapnhash@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Michael Diskin <yhn112@users.noreply.github.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic, 4 years ago
commit 0a0e290ea3

+ 1 - 1
docs/modules/client.rst

@@ -25,4 +25,4 @@
 .. autoclass:: DecentralizedAverager
    :members:
    :member-order: bysource
-   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part
+   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part, register_allreduce_group

+ 55 - 50
hivemind/client/averaging/__init__.py

@@ -20,6 +20,7 @@ import torch
 import numpy as np
 
 from hivemind.dht import DHT, DHTID
+from hivemind.client.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -34,9 +35,8 @@ from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescripto
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
-DataForGather = Any
+GatheredData = Any
 logger = get_logger(__name__)
-DEFAULT_CHUNK_SIZE_BYTES = 2 ** 16
 
 
 class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
@@ -61,7 +61,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
-    :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
+    :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param throughput: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
           If throughput == 0, averager will rely on its groupmates to do all the averaging.
@@ -94,8 +94,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
-                 allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
+                 averaging_expiration: float = 15, request_timeout: float = 3, averaging_alpha: float = 1.0,
+                 part_size_bytes: int = DEFAULT_PART_SIZE_BYTES, allreduce_timeout: Optional[float] = None,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0,
                  auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
@@ -135,7 +135,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
             min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
-        self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes,
+        self.allreduce_kwargs = dict(compression_type=compression_type, part_size_bytes=part_size_bytes,
                                      min_vector_size=min_vector_size)
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -251,8 +251,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if self._parent_pid != os.getpid() or self.is_alive():
             self.shutdown()
 
-    def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None,
-             allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
+    def step(self, gather: Optional[GatheredData] = None, weight: Optional[float] = None,
+             timeout: Optional[float] = None, allow_retries: bool = True, wait: bool = True
+             ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
@@ -265,10 +266,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
-        if self.mode == AveragingMode.AUX and weight != 1:
+        if self.mode == AveragingMode.AUX and weight is not None:
             logger.warning("Averager is running in auxiliary mode, weight is unused.")
-        else:
-            assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+        if weight is None:
+            weight = float(self.mode != AveragingMode.AUX)
+        assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a non-negative int/float, got {type(weight)}"
 
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
@@ -278,9 +280,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
                     allow_retries: bool, timeout: Optional[float]):
-        loop = asyncio.get_event_loop()
         start_time = get_dht_time()
-        group_id = None
 
         try:
             while not future.done():
@@ -291,16 +291,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                                                                         data_for_gather=data_for_gather)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
-                    group_id = group_info.group_id
-                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                    self._running_groups[group_id] = allreduce_runner
-                    self._pending_group_assembled.set()
-                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                    if self.mode != AveragingMode.AUX:
-                        await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
 
-                    # averaging is finished, exit the loop
-                    future.set_result(allreduce_runner.gathered)
+                    future.set_result(await asyncio.wait_for(
+                        self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout))
+                    # averaging is finished, loop will now exit
 
                 except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
                         asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
@@ -311,10 +305,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     else:
                         logger.warning(f"Averager caught {repr(e)}, retrying")
 
-                finally:
-                    _ = self._running_groups.pop(group_id, None)
-                    self._pending_group_assembled.set()
-
         except BaseException as e:
             if not future.done():
                 future.set_exception(e)
@@ -324,35 +314,51 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
                                                   " Please report this to hivemind issues."))
 
-    async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
-        """ Use a group description found by Matchmaking to form AllreduceRunner """
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """ Run All-Reduce in a given group and update tensors in place, return gathered metadata """
         try:
             weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
-            # compute optimal part sizes from peer throughputs
             modes = tuple(map(AveragingMode, mode_ids))
-            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)]  # TODO: replace with proper load balancing
-            part_sizes = await asyncio.get_event_loop().run_in_executor(
+
+            # compute optimal part sizes from peer throughputs; TODO: replace with proper load balancing
+            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0
+                                    for thr, mode in zip(throughputs, modes)]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                 None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
-            async with self.get_tensors_async() as averaged_tensors:
-                return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
-                                       ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
-                                       weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs)
-        except Exception as e:
-            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}")
 
-    def update_tensors(self, allreduce_group: AllReduceRunner):
-        """
-        a private (extendable) method that applies changes from a finished allreduce to local tensors
-        """
-        assert allreduce_group.return_deltas and allreduce_group.future.done()
-        averaging_deltas = allreduce_group.future.result()
+            async with self.get_tensors_async() as local_tensors:
+                allreduce = AllReduceRunner(
+                    group_id=group_info.group_id, tensors=local_tensors, endpoint=self.endpoint,
+                    ordered_group_endpoints=group_info.endpoints, peer_fractions=peer_fractions, weights=weights,
+                    gathered=user_gathered, modes=modes, **kwargs)
 
-        with torch.no_grad(), self.get_tensors() as local_tensors:
-            assert len(local_tensors) == len(self._averaged_tensors)
-            for tensor, update in zip(local_tensors, averaging_deltas):
-                tensor.add_(update, alpha=self._averaging_alpha)
-        self.last_updated = get_dht_time()
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+
+                    # actually run all-reduce
+                    averaging_outputs = [output async for output in allreduce]
+
+                    if modes[group_info.endpoints.index(self.endpoint)] != AveragingMode.AUX:
+                        assert len(local_tensors) == len(self._averaged_tensors)
+                        for tensor, update in zip(local_tensors, averaging_outputs):
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self.last_updated = get_dht_time()
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    @contextlib.contextmanager
+    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
+        """ registers a given all-reduce runner to listen for incoming connections """
+        try:
+            self._running_groups[group_id] = allreduce
+            self._pending_group_assembled.set()
+            yield
+        finally:
+            self._running_groups.pop(group_id, None)
+            self._pending_group_assembled.set()
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -418,11 +424,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """
         if not self.allow_state_sharing:
             return  # deny request and direct peer to the next prospective averager
-        chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
         metadata, tensors = await self._get_current_state_from_host_process()
 
         for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
+            for part in split_for_streaming(serialize_torch_tensor(tensor)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
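
For orientation, a hedged sketch of how the averager would now be constructed and stepped with the renamed part_size_bytes argument and the new weight semantics; the DHT, prefix, group size and tensor values below are placeholders rather than anything prescribed by this PR:

import torch

from hivemind.dht import DHT
from hivemind.client.averaging import DecentralizedAverager
from hivemind.proto import runtime_pb2

dht = DHT(start=True)  # placeholder DHT; a real run would also pass initial_peers, a listen address, etc.
averager = DecentralizedAverager(
    averaged_tensors=[torch.zeros(1024)], dht=dht, start=True,
    prefix="demo_run", target_group_size=4,            # placeholder matchmaking settings
    part_size_bytes=2 ** 20,                           # replaces the former chunk_size_bytes argument
    compression_type=runtime_pb2.CompressionType.FLOAT16)

# weight now defaults to None: regular peers average with weight 1.0, auxiliary peers with 0.0
gathered = averager.step(weight=None, timeout=60)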

+ 159 - 206
hivemind/client/averaging/allreduce.py

@@ -1,14 +1,15 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any, Optional
+from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
 from enum import Enum
 
 import grpc
 import torch
 
-from hivemind.utils import Endpoint, get_logger, ChannelCache, anext
-from hivemind.utils import split_for_streaming, combine_from_streaming
+from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
+from hivemind.utils import Endpoint, get_logger, ChannelCache
+from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
+from hivemind.proto import averaging_pb2_grpc, averaging_pb2
 
 # flavour types
 GroupID = bytes
@@ -21,256 +22,208 @@ class AveragingMode(Enum):
     AUX = 2
 
 
-class AllReduceProtocol:
+class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers
 
+    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
     :param endpoint: your endpoint, must be included in ordered_group_endpoints
     :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
-    :param part_sizes: for each peer, a number of vector elements that this peer is responsible for averaging
-    :param return_deltas: if True, returns the element-wise differences (averaged_tensors - original_tensors)
-           default (False) - return averaged_tensors by themselves
+    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
+      (the actual number of values per peer will be nearly proportional, but there are no exact guarantees)
+    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
+    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
+    :param gathered: additional user-defined data collected from this group
+    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """
 
-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False,
-                 modes: Optional[Sequence[AveragingMode]] = None):
+    def __init__(
+            self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
+            ordered_group_endpoints: Sequence[Endpoint], peer_fractions: Tuple[float, ...],
+            weights: Optional[Sequence[float]] = None, modes: Optional[Sequence[AveragingMode]] = None,
+            gathered: Optional[Dict[Endpoint, Any]] = None, **kwargs):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
-        self.group_id, self.endpoint = group_id, endpoint
-        self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
-        if modes is None:
-            modes = [AveragingMode.CLIENT if part_size == 0 else AveragingMode.NODE for part_size in part_sizes]
-        assert any(mode != AveragingMode.CLIENT for mode in modes), "Cannot run allreduce without reducers."
-        self.peer_modes = dict(zip(ordered_group_endpoints, modes))
-
-        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
-        self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
-        self.return_deltas = return_deltas
-
-        self.accumulator = torch.zeros_like(self.local_tensor_parts[self.endpoint])
-        self.denominator = 0.0  # number of peers added to accumulator or sum of their weights
-        self.accumulated_from: Set[Endpoint] = set()  # peers that we have accumulated our part from
-        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
-        self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
-        self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-
-        self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX])
-
-        if self.num_senders == 0:
-            self.future.set_result(None)
-        for endpoint, mode in self.peer_modes.items():
-            if mode == AveragingMode.CLIENT:
-                self.averaged_tensor_parts[endpoint] = torch.tensor([])
+        modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
+        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
+        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
+        assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
+        for mode, frac, weight in zip(modes, peer_fractions, weights):
+            assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
+            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
+
+        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+
+        self._future = asyncio.Future()
+
+        self.sender_endpoints, self.sender_weights = [], []
+        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
+            if mode != AveragingMode.AUX:
+                self.sender_endpoints.append(endpoint)
+                self.sender_weights.append(weight)
+
+        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
+        self.tensor_part_reducer = TensorPartReducer(tuple(part.shape for part in self.parts_for_local_averaging),
+                                                     len(self.sender_endpoints), self.sender_weights)
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
 
-    def __await__(self):
-        return self.future.__await__()
+    def __aiter__(self):
+        return self.run()
 
     def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.local_tensor_parts
+        return endpoint in self.ordered_group_endpoints
 
     @property
     def group_size(self):
         return len(self.ordered_group_endpoints)
 
-    async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
-        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
-        assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
-        assert source not in self.accumulated_from, "duplicate source, already received that part"
-        assert self.peer_modes[self.endpoint] != AveragingMode.CLIENT, f"{self.endpoint} is in AveragingMode.client mode"
-        assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
-
-        logger.debug(f"{self} - accumulating tensor part from {source}")
-        self.accumulator.add_(remote_part, alpha=weight)
-        self.denominator += weight
-        self.accumulated_from.add(source)
-
-        assert len(self.accumulated_from) <= self.num_senders
-        if len(self.accumulated_from) == self.num_senders:
-            average_result = self.accumulator.div_(self.denominator)
-            self.averaged_part.set_result(average_result)
-
-            if self.peer_modes[self.endpoint] == AveragingMode.AUX:
-                self.future.set_result(None)  # auxiliary mode has finished averaging
-            else:
-                self.register_averaged_part(self.endpoint, average_result)
-
-        return await self.averaged_part
-
-    def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "the provider of averaged part is not from my group"
-        assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
-        assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
-        assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
-        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending"
-        logger.debug(f"{self} - receiving averaged tensor part from {source}")
-        self.averaged_tensor_parts[source] = averaged_part
-        if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
-            ordered_averaged_parts = [self.averaged_tensor_parts[endpoint] for endpoint in self.ordered_group_endpoints]
-            outputs = restore_from_parts(ordered_averaged_parts, self.tensor_shapes)
-
-            if self.return_deltas:
-                local_parts = [self.local_tensor_parts[peer] for peer in self.ordered_group_endpoints]
-                with torch.no_grad():
-                    original_tensors = restore_from_parts(local_parts, self.tensor_shapes)
-                    for averaged_tensor, original_tensor in zip(outputs, original_tensors):
-                        averaged_tensor -= original_tensor
-
-            self.future.set_result(outputs)
-
-    def cancel(self) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - cancelled")
-            self.future.cancel()
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.future}")
-            return False
-
-    def set_exception(self, exception: Exception) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - {exception}")
-            self.future.set_exception(exception)
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.future}")
-            return False
-
-
-class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragingServicer):
-    """
-    A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
-    """
-
-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
-                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
-                 gathered: Dict[Endpoint, Any], return_deltas: bool = False, **kwargs):
-        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
-                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas, **kwargs)
-        self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
-        self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
-
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
 
-    async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
-        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
-        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
-        if peer_endpoint == self.endpoint:
-            return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
-        serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-        chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)
-
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
-                                                       endpoint=self.endpoint, tensor_part=next(chunks)))
-        for chunk in chunks:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
-        await stream.done_writing()
-
-        outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
-        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
-        if code != averaging_pb2.AVERAGED_PART:
-            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
-                                     f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
-                                     f" allreduce failed")
-
+    async def run(self) -> AsyncIterator[torch.Tensor]:
+        """ Run all-reduce, return differences between averaged and original tensors as they are computed """
+        pending_tasks = set()
         try:
-            averaged_part = local_part + deserialize_torch_tensor(combine_from_streaming(
-                [message.tensor_part for message in outputs]))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
+            if len(self.sender_endpoints) == 0:
+                logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
+                self.finalize()
 
-        self.register_averaged_part(peer_endpoint, averaged_part)
-        return averaged_part
+            elif self.endpoint in self.sender_endpoints:
+                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
+                    if parts != 0:
+                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))
 
-    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+                async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
+                    yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+                self.finalize()
+
+            else:  # auxiliary peer
+                await self.tensor_part_reducer.finished.wait()
+                self.finalize()
 
-    async def run(self) -> Sequence[torch.Tensor]:
-        """
-        send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
-        """
-        try:
-            if self.peer_modes[self.endpoint] != AveragingMode.AUX:
-                await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
-                                            for i, peer in enumerate(self.ordered_group_endpoints)
-                                            if self.peer_modes[peer] != AveragingMode.CLIENT))
-            return await self
         except BaseException as e:
+            self.finalize(exception=e)
+            for task in pending_tasks:
+                task.cancel()
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            self.set_exception(e)
-            for peer_endpoint, part_size in zip(self.ordered_group_endpoints, self.part_sizes):
-                if peer_endpoint != self.endpoint and part_size > 0:
+            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
+                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
                     asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
             raise
 
-    async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
-                                        ) -> Iterable[runtime_pb2.Tensor]:
-        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
-        try:
-            tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")
+    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
+        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
+        if peer_endpoint == self.endpoint:
+            sender_index = self.sender_endpoints.index(peer_endpoint)
+            for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
+                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
-        averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
-        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type, allow_inplace=False)
-        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
-        return stream_chunks
+        else:
+            loop = asyncio.get_event_loop()
+            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
+
+            try:
+                code = None
+                async for part_index, msg in aenumerate(stream):
+                    if code is None:
+                        code = msg.code
+                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+                await write_task
+
+                if code != averaging_pb2.AVERAGED_PART:
+                    raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                                             f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                                             f", allreduce failed")
+            finally:
+                if not write_task.done():
+                    write_task.cancel()
+
+    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+        first_part = await anext(parts_aiter)
+        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
+                                                       group_id=self.group_id, endpoint=self.endpoint,
+                                                       tensor_part=first_part))
+        async for part in parts_aiter:
+            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
+
+        await stream.done_writing()
 
     async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
                                  ) -> AsyncIterator[averaging_pb2.AveragingData]:
-        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the delta"""
+        """ a peer sends us a part of his tensor; we should average it with other peers and return the difference """
         request: averaging_pb2.AveragingData = await anext(stream)
-
-        if request.group_id != self.group_id:
-            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        reason_to_reject = self._check_reasons_to_reject(request)
+        if reason_to_reject:
+            yield reason_to_reject
+            return
 
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                tensor_chunks = (request.tensor_part, *[msg.tensor_part async for msg in stream])
-                averaged_chunks = iter(await self.accumulate_part_streaming(request.endpoint, tensor_chunks))
-                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=next(averaged_chunks))
-                for averaged_chunk in averaged_chunks:
-                    yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
+                sender_index = self.sender_endpoints.index(request.endpoint)
+                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
+                    yield msg
 
             except Exception as e:
-                self.set_exception(e)
+                self.finalize(exception=e)
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
             logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
-            self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
+    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+        if request.group_id != self.group_id:
+            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        elif self._future.cancelled():
+            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        elif self._future.done():
+            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+
+    async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
+        loop = asyncio.get_event_loop()
+        async for part_index, (tensor_part, part_compression) in aenumerate(
+                amap_in_executor(lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression), stream,
+                                 max_prefetch=self.tensor_part_container.prefetch)):
+            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+
+            serialized_delta = await loop.run_in_executor(
+                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression))
+            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
 
-def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]:
-    """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
-    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
-    return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)
-
-
-def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
-    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
-    flat_tensor = torch.cat(tuple(chunks))
-    result_sizes = tuple(map(torch.Size.numel, shapes))
-    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
-    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
-
+    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
+        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
+        await stream.done_writing()
 
-class AllreduceException(Exception):
-    """ A special exception that is raised when allreduce can't continue normally (e.g. disbanded/bad request/etc) """
+    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
+        assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
+        if not self._future.done():
+            if cancel:
+                logger.debug(f"{self} - cancelled")
+                self._future.cancel()
+            elif exception:
+                logger.debug(f"{self} - caught {exception}")
+                self._future.set_exception(exception)
+            else:
+                logger.debug(f"{self} - finished")
+                self._future.set_result(None)
+            self.tensor_part_container.finalize()
+            self.tensor_part_reducer.finalize()
+            return True
+        else:
+            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
+            return False
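
The rewritten runner is consumed as an async iterator instead of being awaited. The sketch below mirrors _run_allreduce in the averager; run_allreduce_round and its arguments are hypothetical, only the AllReduceRunner interface comes from this file:

from typing import Sequence

import torch

from hivemind.client.averaging.allreduce import AllReduceRunner
from hivemind.utils import Endpoint

async def run_allreduce_round(group_id: bytes, local_tensors: Sequence[torch.Tensor],
                              endpoint: Endpoint, group_endpoints: Sequence[Endpoint]):
    """ hypothetical helper: run one all-reduce round and apply the resulting deltas in place """
    # endpoint must be one of group_endpoints; equal peer_fractions are an arbitrary choice here
    allreduce = AllReduceRunner(
        group_id=group_id, tensors=local_tensors, endpoint=endpoint,
        ordered_group_endpoints=group_endpoints,
        peer_fractions=tuple(1 / len(group_endpoints) for _ in group_endpoints))

    # the runner yields one delta (averaged - local) per tensor as soon as that tensor is assembled
    deltas = [delta async for delta in allreduce]
    with torch.no_grad():
        for tensor, delta in zip(local_tensors, deltas):
            tensor.add_(delta)
    return allreduce.gathered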

+ 1 - 0
hivemind/client/averaging/load_balancing.py

@@ -28,6 +28,7 @@ def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_
         assert not all(throughput == 0 for throughput in throughputs), "Must have at least one nonzero throughput"
         scores = np.asarray([1.0 if throughput is None else 0.0 for throughput in throughputs])
 
+    #TODO(jheuristic) we no longer need hagenbach-bishoff with new AllReduceRunner
     return tuple(hagenbach_bishoff(vector_size, scores))
 
 

+ 224 - 0
hivemind/client/averaging/partition.py

@@ -0,0 +1,224 @@
+"""
+Auxiliary data structures for AllReduceRunner
+"""
+import asyncio
+from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
+from collections import deque
+
+import torch
+import numpy as np
+
+from hivemind.proto.runtime_pb2 import CompressionType, Tensor
+from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
+from hivemind.utils.asyncio import amap_in_executor
+
+
+T = TypeVar('T')
+DEFAULT_PART_SIZE_BYTES = 2 ** 20
+
+
+class TensorPartContainer:
+    """
+    Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
+    The class is designed to avoid excessive memory allocation and run all heavy computation in background
+    :param tensors: local tensors to be split and aggregated
+    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
+    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
+    :param prefetch: when compressing, pre-compute this many compressed tensors in background
+    """
+
+    def __init__(self, tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float],
+                 compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
+                 part_size_bytes: int = 2 ** 20, prefetch: int = 1):
+        if not isinstance(compression_type, Sequence):
+            compression_type = [compression_type] * len(tensors)
+        assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
+        self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
+        self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
+        self.total_size = sum(tensor.numel() for tensor in tensors)
+        self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
+        self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
+        self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
+        self._outputs_consumed = False
+        self.finished = asyncio.Event()
+        self.num_parts_by_tensor = []
+
+        # split tensor parts in proportion to target_size_by_peer
+        current_length = 0
+        current_peer_index = 0
+        pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
+        pivots[-1] = self.total_size
+
+        for tensor, tensor_compression in zip(self.local_tensors, compression_type):
+            part_size_values = int(part_size_bytes / get_nbytes_per_value(tensor.dtype, tensor_compression))
+            tensor_parts = tensor.detach().view(-1).split(part_size_values)
+            self.num_parts_by_tensor.append(len(tensor_parts))
+            for part in tensor_parts:
+                if current_length + len(part) > pivots[current_peer_index]:
+                    # switch to next peer; if a part lands between parts of two or
+                    # more peers, assign that part to the peer with highest intersection
+                    prev_peer_index = current_peer_index
+                    peer_intersections = [pivots[current_peer_index] - current_length]
+                    while current_length + len(part) > pivots[current_peer_index]:
+                        current_peer_index += 1
+                        current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
+                        peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
+                    assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
+                    self._input_parts_by_peer[assigned_peer_index].append((part, tensor_compression))
+                else:
+                    self._input_parts_by_peer[current_peer_index].append((part, tensor_compression))
+                current_length += len(part)
+
+        assert current_length == self.total_size
+        self.num_parts_by_peer = tuple(len(parts) for parts in self._input_parts_by_peer)
+
+    @torch.no_grad()
+    def get_raw_input_parts(self, peer_index: int) -> Tuple[torch.Tensor, ...]:
+        """ get non-serialized tensor parts for a peer at a given index """
+        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
+        self._inputs_consumed_by_peer[peer_index] = True
+        input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
+        self._input_parts_by_peer[peer_index].clear()
+        return input_parts
+
+    @torch.no_grad()
+    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
+        """ iterate serialized tensor parts for a peer at a given index. Run serialization in background. """
+        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
+        self._inputs_consumed_by_peer[peer_index] = True
+
+        async def _aiterate_parts():
+            for _ in range(self.num_parts_by_peer[peer_index]):
+                yield self._input_parts_by_peer[peer_index].popleft()
+
+        async for serialized_part in amap_in_executor(lambda x_and_compr: serialize_torch_tensor(*x_and_compr),
+                                                      _aiterate_parts(), max_prefetch=self.prefetch):
+            yield serialized_part
+
+    def register_processed_part(self, peer_index: int, part_index: int, part: torch.Tensor):
+        """
+        register next-in-line part of results received from a given peer for use in iterate_output_tensors
+        depending on the algorithm, processed part is an average, difference from average or another aggregation
+        """
+        if part_index != self._outputs_registered_by_peer[peer_index]:
+            raise ValueError(f"Could not register part #{part_index} from peer #{peer_index}, "
+                             f" expected part index: {self._outputs_registered_by_peer[peer_index]}")
+        self._output_parts_by_peer[peer_index].append(part)
+        self._outputs_registered_by_peer[peer_index] += 1
+        self._output_part_available[peer_index].set()
+
+    async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
+        """ iterate over the outputs of averaging (whether they are average, delta or other aggregation result) """
+        assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
+        self._outputs_consumed = True
+        peer_index = num_parts_processed = 0
+        for tensor_index in range(len(self.local_tensors)):
+            tensor_parts = []
+            while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
+                if num_parts_processed >= self.num_parts_by_peer[peer_index]:
+                    num_parts_processed = 0
+                    peer_index += 1
+                    continue
+                if not self._output_parts_by_peer[peer_index]:
+                    self._output_part_available[peer_index].clear()
+                    await self._output_part_available[peer_index].wait()
+                    if self.finished.is_set():
+                        raise AllreduceException("All-reduce was terminated during iteration.")
+
+                tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
+                num_parts_processed += 1
+            tensor = torch.cat(tensor_parts)
+            del tensor_parts
+            yield tensor.reshape(self.local_tensors[tensor_index].shape)
+
+    def __del__(self):
+        self.finalize()
+
+    def finalize(self):
+        """ terminate all iterators, delete intermediate data """
+        if not self.finished.is_set():
+            for peer_index in range(self.group_size):
+                self._inputs_consumed_by_peer[peer_index] = True
+                self._input_parts_by_peer[peer_index].clear()
+                self._output_parts_by_peer[peer_index].clear()
+                self._output_part_available[peer_index].set()
+            self._outputs_consumed = True
+            self.finished.set()
+
+
+class TensorPartReducer:
+    """
+    Auxiliary data structure responsible for running asynchronous all-reduce
+    :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
+    :param num_senders: total number of peers in a given all-reduce group that will send gradients
+    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
+    :note: even if local peer is not sending data, local parts will be used for shape information
+    """
+
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int,
+                 weights: Optional[Sequence[float]] = None):
+        self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
+        self.weights = tuple(weights or (1 for _ in range(num_senders)))
+        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
+        assert all(isinstance(weight, (int, float)) for weight in self.weights)
+        self.current_part_index = -1  # index in local_parts of the part that should be loaded next
+        self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
+        self.accumulator = None  # this will contain the sum of current tensor part from group peers
+        self.denominator = 0.0  # total weight accumulated from all peers for current part
+        self.current_part_future = asyncio.Future()
+        self.finished = asyncio.Event()
+        self.reset_accumulators()
+
+    def reset_accumulators(self):
+        """ (re)create averaging buffers for the next part in line, prepopulate with local tensor part """
+        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
+        if self.current_part_index >= self.num_parts - 1:
+            self.finalize()
+            return
+
+        self.current_part_index += 1
+        self.current_part_accumulated_from = 0
+        self.current_part_future = asyncio.Future()
+        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
+        self.denominator = 0.0
+
+    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
+        assert 0 <= sender_index < self.num_senders, "invalid sender index"
+        assert 0 <= part_index < self.num_parts, "invalid part index"
+
+        while part_index > self.current_part_index:
+            # wait for previous parts to finish processing ...
+            await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
+            if self.finished.is_set():
+                raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
+        assert part_index == self.current_part_index
+
+        current_part_future = self.current_part_future
+
+        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
+        self.denominator += self.weights[sender_index]
+        self.current_part_accumulated_from += 1
+
+        assert self.current_part_accumulated_from <= self.num_senders
+        if self.current_part_accumulated_from == self.num_senders:
+            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            self.reset_accumulators()
+        return await current_part_future
+
+    def finalize(self):
+        if not self.finished.is_set():
+            if hasattr(self, 'current_part_future'):
+                self.current_part_future.cancel()
+                del self.accumulator
+            self.finished.set()
+
+    def __del__(self):
+        self.finalize()
+
+
+class AllreduceException(Exception):
+    """ A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error) """

+ 1 - 1
hivemind/proto/averaging.proto

@@ -43,7 +43,7 @@ message MessageFromLeader {
   bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
   string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
   repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
-  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endoints
+  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endpoints
 }
 
 message AveragingData {

+ 49 - 1
hivemind/utils/asyncio.py

@@ -1,7 +1,14 @@
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable
+from concurrent.futures import ThreadPoolExecutor
+from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable, Tuple, Optional, Callable
 import asyncio
+
 import uvloop
+
+from hivemind.utils.logging import get_logger
+
+
 T = TypeVar('T')
+logger = get_logger(__name__)
 
 
 def switch_to_uvloop() -> asyncio.AbstractEventLoop:
@@ -27,6 +34,16 @@ async def aiter(*args: T) -> AsyncIterator[T]:
         yield arg
 
 
+async def azip(*iterables: AsyncIterable[T]) -> AsyncIterator[Tuple[T, ...]]:
+    """ equivalent of zip for asynchronous iterables """
+    iterators = [iterable.__aiter__() for iterable in iterables]
+    while True:
+        try:
+            yield tuple(await asyncio.gather(*(itr.__anext__() for itr in iterators)))
+        except StopAsyncIteration:
+            break
+
+
 async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
     """ equivalent to chain(iter1, iter2, ...) for asynchronous iterators. """
     for aiter in async_iters:
@@ -34,6 +51,14 @@ async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
             yield elem
 
 
+async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]]:
+    """ equivalent to enumerate(iter) for asynchronous iterators. """
+    index = 0
+    async for elem in aiterable:
+        yield index, elem
+        index += 1
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
@@ -42,3 +67,26 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
         return True
     except BaseException:
         return False
+
+
+async def amap_in_executor(func: Callable[..., T], *iterables: AsyncIterable, max_prefetch: Optional[int] = None,
+                           executor: Optional[ThreadPoolExecutor] = None) -> AsyncIterator[T]:
+    """ iterate from an async iterable in a background thread, yield results to async iterable """
+    loop = asyncio.get_event_loop()
+    queue = asyncio.Queue(max_prefetch)
+
+    async def _put_items():
+        async for args in azip(*iterables):
+            await queue.put(loop.run_in_executor(executor, func, *args))
+        await queue.put(None)
+
+    task = asyncio.create_task(_put_items())
+    try:
+        future = await queue.get()
+        while future is not None:
+            yield await future
+            future = await queue.get()
+        await task
+    finally:
+        if not task.done():
+            task.cancel()

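To make the intended use of the new asyncio helpers concrete, a small self-contained sketch (the values are arbitrary; azip and aenumerate combine async iterables, while amap_in_executor offloads the mapped function to a thread pool so the event loop stays responsive):

import asyncio

from hivemind.utils.asyncio import aiter, azip, aenumerate, amap_in_executor


async def demo():
    # zip two async iterables and enumerate the resulting pairs
    pairs = [(i, pair) async for i, pair in aenumerate(azip(aiter(1, 2, 3), aiter('a', 'b', 'c')))]
    assert pairs == [(0, (1, 'a')), (1, (2, 'b')), (2, (3, 'c'))]

    # run a (potentially CPU-bound) function in a background thread,
    # prefetching at most two results ahead of the consumer
    squares = [y async for y in amap_in_executor(lambda x: x * x, aiter(*range(5)), max_prefetch=2)]
    assert squares == [0, 1, 4, 9, 16]


asyncio.run(demo())
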
+ 12 - 0
hivemind/utils/compression.py

@@ -188,3 +188,15 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
 
     tensor.requires_grad_(serialized_tensor.requires_grad)
     return tensor
+
+
+def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
+    """ returns the number of bytes per value for a given tensor (excluding metadata) """
+    if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
+        return 1
+    elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
+        return 2
+    elif compression == CompressionType.NONE:
+        return torch.finfo(dtype).bits // 8
+    else:
+        raise NotImplementedError(f"Unknown compression type: {CompressionType.Name(compression)}")

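This helper lets the partitioner turn a per-part byte budget into a number of tensor values for a given compression type; a back-of-the-envelope sketch (the 2 ** 19 budget is an illustrative stand-in for DEFAULT_PART_SIZE_BYTES from hivemind/client/averaging/partition.py):

import torch

from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import get_nbytes_per_value

part_size_bytes = 2 ** 19  # illustrative budget only
for compression in (CompressionType.NONE, CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT):
    nbytes = get_nbytes_per_value(torch.float32, compression)
    print(CompressionType.Name(compression), nbytes, part_size_bytes // nbytes)
# NONE: 4 bytes/value -> 131072 values, FLOAT16: 2 -> 262144, UNIFORM_8BIT: 1 -> 524288
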
+ 5 - 1
hivemind/utils/grpc.py

@@ -158,7 +158,11 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
 
-def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int) -> Iterator[runtime_pb2.Tensor]:
+STREAMING_CHUNK_SIZE_BYTES = 2 ** 16
+
+
+def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
+                        ) -> Iterator[runtime_pb2.Tensor]:
     """ Split serialized_tensor into multiple chunks for gRPC streaming """
     buffer = memoryview(serialized_tensor.buffer)
     num_chunks = len(range(0, len(buffer), chunk_size_bytes))

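For completeness, a sketch of the round trip that split_for_streaming enables, paired with the pre-existing combine_from_streaming helper from the same module (if that counterpart lives elsewhere in your checkout, adjust the import accordingly):

import torch

from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.grpc import split_for_streaming, combine_from_streaming, STREAMING_CHUNK_SIZE_BYTES

tensor = torch.randn(1024, 1024)
serialized = serialize_torch_tensor(tensor)  # CompressionType.NONE by default

# the sender cuts the serialized tensor into <= 64 KiB chunks before writing them to a gRPC stream
chunks = list(split_for_streaming(serialized, STREAMING_CHUNK_SIZE_BYTES))

# the receiver reassembles the chunks and deserializes the result
restored = deserialize_torch_tensor(combine_from_streaming(chunks))
assert torch.allclose(restored, tensor)
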
+ 217 - 0
tests/test_allreduce.py

@@ -0,0 +1,217 @@
+import asyncio
+import random
+import time
+from typing import Sequence
+
+import pytest
+import torch
+import grpc
+
+from hivemind import aenumerate, Endpoint
+from hivemind.client.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer
+from hivemind.utils import deserialize_torch_tensor, ChannelCache
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.proto import averaging_pb2_grpc
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning():
+    all_tensors = [
+        torch.randn(30_000, 128), torch.rand(128), torch.ones(1, 1, 1, 1, 1, 1, 8),
+        torch.ones(1, 0), torch.zeros(0), torch.zeros([]), torch.randn(65536),
+        torch.rand(512, 2048), torch.randn(1024, 1024).add(-9), torch.zeros(1020), torch.randn(4096)
+    ]
+
+    # note: this test intentionally avoids pytest parameterization so that the sampled tensors are reused across configurations
+    for num_tensors in 1, 3, 5:
+        for part_size_bytes in 31337, 2 ** 20, 10 ** 10:
+            for weights in [(1, 1), (0.333, 0.1667, 0.5003), (1.0, 0.0), [0.0, 0.4, 0.6, 0.0]]:
+                tensors = random.choices(all_tensors, k=num_tensors)
+                partition = TensorPartContainer(tensors, weights, part_size_bytes=part_size_bytes)
+
+                async def write_tensors():
+                    for peer_index in range(partition.group_size):
+                        async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+                            output_tensor = torch.sin(deserialize_torch_tensor(part))
+                            partition.register_processed_part(peer_index, part_index, output_tensor)
+
+                task = asyncio.create_task(write_tensors())
+                tensor_index = 0
+                async for output_tensor in partition.iterate_output_tensors():
+                    assert torch.allclose(output_tensor, torch.sin(tensors[tensor_index]))
+                    tensor_index += 1
+                assert tensor_index == len(tensors)
+                await task
+
+
+@pytest.mark.parametrize("tensors", [[torch.zeros(0)], [torch.zeros(0), torch.zeros(0), torch.zeros(1)],
+                                     [torch.zeros(0), torch.zeros(999), torch.zeros(0), torch.zeros(0)]])
+@pytest.mark.parametrize("peer_fractions", [(0.33, 0.44, 0.23), (0.5, 0.5), (0.1, 0.0, 0.9), (1.0,), (0.1,) * 9])
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning_edge_cases(tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float]):
+    partition = TensorPartContainer(tensors, peer_fractions, part_size_bytes=16)
+    for peer_index in range(len(peer_fractions)):
+        async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+            partition.register_processed_part(peer_index, part_index, deserialize_torch_tensor(part))
+
+    tensor_index = 0
+    async for output_tensor in partition.iterate_output_tensors():
+        assert torch.allclose(output_tensor, tensors[tensor_index])
+        tensor_index += 1
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning_asynchronous():
+    """ ensure that tensor partitioning does not interfere with asynchronous code """
+    tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096),
+               torch.randn(4096, 1024), torch.randn(30_000, 1024)]
+    peer_fractions = [0.4, 0.3, 0.2, 0.1]
+
+    partition = TensorPartContainer(tensors, peer_fractions, compression_type=CompressionType.QUANTILE_8BIT)
+    read_started, read_finished = asyncio.Event(), asyncio.Event()
+
+    async def write_tensors():
+        for peer_index in range(partition.group_size):
+            async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+                partition.register_processed_part(peer_index, part_index, deserialize_torch_tensor(part))
+        assert read_started.is_set(), "the reader should have started processing outputs before the writer finished"
+
+    async def read_tensors():
+        async for _ in partition.iterate_output_tensors():
+            read_started.set()
+        read_finished.set()
+
+    async def wait_synchronously():
+        time_in_waiting = 0.0
+        while not read_finished.is_set():
+            await asyncio.sleep(0.01)
+            time_in_waiting += 0.01
+        return time_in_waiting
+
+    start_time = time.perf_counter()
+    *_, time_in_waiting = await asyncio.gather(write_tensors(), read_tensors(), wait_synchronously())
+    wall_time = time.perf_counter() - start_time
+    # check that event loop had enough time to respond to incoming requests; this is over 50% most of the time
+    # we set 33% threshold to ensure that the test will pass reliably. If we break prefetch, this drops to <10%
+    assert time_in_waiting > wall_time / 3, f"Event loop could only run {time_in_waiting / wall_time :.5f} of the time"
+
+
+@pytest.mark.parametrize("num_senders", [1, 2, 4, 10])
+@pytest.mark.parametrize("num_parts", [0, 1, 100])
+@pytest.mark.parametrize("synchronize_prob", [1.0, 0.1, 0.0])
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
+    tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
+    reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+
+    local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)]
+                               for j in range(num_senders)]
+
+    async def send_tensors(sender_index: int):
+        local_tensors = local_tensors_by_sender[sender_index]
+        averaged_parts = []
+        pending_tasks = []
+
+        for part_index in range(num_parts):
+            pending_tasks.append(asyncio.create_task(
+                reducer.accumulate_part(sender_index, part_index, local_tensors[part_index])))
+
+            if random.random() < synchronize_prob or part_index == num_parts - 1:
+                averaged_parts.extend(await asyncio.gather(*pending_tasks))
+                pending_tasks = []
+        return averaged_parts
+
+    averaged_tensors_by_peer = await asyncio.gather(*map(send_tensors, range(num_senders)))
+
+    reference = [sum(local_tensors_by_sender[sender_index][part_index]
+                     for sender_index in range(num_senders)) / num_senders
+                 for part_index in range(num_parts)]
+
+    for averaged_tensors in averaged_tensors_by_peer:
+        assert len(averaged_tensors) == len(reference)
+        for averaging_result, reference_tensor in zip(averaged_tensors, reference):
+            assert torch.allclose(averaging_result, reference_tensor, rtol=1e-3, atol=1e-5)
+
+
+class AllreduceRunnerForTesting(AllReduceRunner):
+    """ a version of AllReduceRunner that was monkey-patched to accept custom endpoint names """
+    def __init__(self, *args, peer_endpoints, **kwargs):
+        self.__peer_endpoints = peer_endpoints
+        super().__init__(*args, **kwargs)
+
+    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
+        return ChannelCache.get_stub(
+            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+
+
+NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
+
+
+@pytest.mark.parametrize("peer_modes, averaging_weights, peer_fractions", [
+    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1)),
+    ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1)),
+    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0)),
+    ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0)),
+    ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4)),
+    ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1)),
+    ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0)),
+    ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4)),
+])
+@pytest.mark.parametrize("part_size_bytes", [2 ** 20, 256, 19],)
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
+    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
+
+    peers = "alice", "bob", "carol", "colab"
+
+    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
+                       for i, peer in enumerate(peers)}
+
+    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
+
+    servers = []
+    allreduce_protocols = []
+    peer_endpoints = {}
+
+    for peer in peers:
+        server = grpc.aio.server()
+        allreduce_protocol = AllreduceRunnerForTesting(
+            group_id=group_id, endpoint=peer, tensors=[x.clone() for x in tensors_by_peer[peer]],
+            ordered_group_endpoints=peers, peer_fractions=peer_fractions, modes=peer_modes,
+            weights=averaging_weights, peer_endpoints=peer_endpoints, part_size_bytes=part_size_bytes
+        )
+        averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(allreduce_protocol, server)
+        peer_endpoints[peer] = f"127.0.0.1:{server.add_insecure_port('127.0.0.1:*')}"
+        allreduce_protocols.append(allreduce_protocol)
+        servers.append(server)
+        await server.start()
+
+    async def _run_allreduce_inplace(allreduce: AllReduceRunner):
+        async for tensor_index, tensor_delta in aenumerate(allreduce):
+            allreduce.tensor_part_container.local_tensors[tensor_index].add_(tensor_delta)
+
+    await asyncio.gather(*map(_run_allreduce_inplace, allreduce_protocols))
+
+    reference_tensors = [sum(tensors_by_peer[peer][i] * averaging_weights[peer_index]
+                             for peer_index, peer in enumerate(peers)) / sum(averaging_weights)
+                         for i in range(len(tensors_by_peer[peers[0]]))]
+
+    for peer_index, protocol in enumerate(allreduce_protocols):
+        assert protocol._future.done()
+        if protocol.modes[peer_index] != AveragingMode.AUX:
+            targets_for_peer = reference_tensors
+        else:
+            targets_for_peer = tensors_by_peer[peers[peer_index]]
+        output_tensors = protocol.tensor_part_container.local_tensors
+        assert len(output_tensors) == len(targets_for_peer)
+        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
+                   for our, ref in zip(output_tensors, targets_for_peer))
+
+    for server in servers:
+        await server.stop(grace=1)

+ 45 - 67
tests/test_averaging.py

@@ -1,14 +1,13 @@
-import asyncio
 import random
 
 import numpy as np
 import torch
 import pytest
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts, AveragingMode
+from hivemind.client.averaging.allreduce import AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager
-from hivemind.utils import Endpoint
+from hivemind.proto.runtime_pb2 import CompressionType
 
 
 @pytest.mark.forked
@@ -47,13 +46,13 @@ def _test_allreduce_once(n_clients, n_aux):
     n_peers = 4
     modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
     random.shuffle(modes)
-    
+
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
     peer_tensors = [tensors1, tensors2, tensors3, tensors4]
-    
+
     reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
                  if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]
 
@@ -130,6 +129,47 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht.shutdown()
 
 
+@pytest.mark.forked
+def test_allreduce_compression():
+    """ this test ensures that compression works correctly when multiple tensors have different compression types """
+    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+
+    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
+    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
+    results = {}
+
+    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
+
+    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        averager1 = hivemind.DecentralizedAverager([x.clone() for x in tensors1], dht=dht,
+                                                   compression_type=compression_type_pair, listen=False,
+                                                   target_group_size=2, prefix='mygroup', start=True)
+        averager2 = hivemind.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
+                                                   compression_type=compression_type_pair,
+                                                   target_group_size=2, prefix='mygroup', start=True)
+
+        for future in averager1.step(wait=False), averager2.step(wait=False):
+            future.result()
+
+        with averager1.get_tensors() as averaged_tensors:
+            results[compression_type_pair] = averaged_tensors
+
+    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
+    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
+    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
+    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
+
+    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
+    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
+    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
+    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
+
+    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
+    for i in range(2):
+        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
+        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
+
+
 def compute_mean_std(averagers, unbiased=True):
     results = []
     for averager in averagers:
@@ -201,68 +241,6 @@ def test_allgather():
     dht.shutdown()
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_allreduce_protocol():
-    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
-    peers = "alice", "bob", "carol", "colab"
-
-    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
-                       for i, peer in enumerate(peers)}
-
-    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
-    allreduce_protocols = [AllReduceProtocol(
-        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
-        ordered_group_endpoints=peers, part_sizes=(150, 200, 67, 0))
-        for peer in peers]
-
-    async def _accumulate(sender: Endpoint, recipient: Endpoint):
-        sender_allreduce = allreduce_protocols[peers.index(sender)]
-        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
-        averaged_part = await recipient_allreduce.accumulate_part(
-            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
-        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)
-
-    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
-                        if recipient != "colab"})
-
-    reference_tensors = [
-        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
-        for i in range(len(tensors_by_peer[peers[0]]))
-    ]
-
-    for peer, allreduce in zip(peers, allreduce_protocols):
-        assert allreduce.future.done()
-        averaged_tensors = await allreduce
-        assert len(averaged_tensors) == len(reference_tensors)
-        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
-                   for our, ref in zip(averaged_tensors, reference_tensors))
-
-
-@pytest.mark.forked
-def test_partitioning():
-    for _ in range(100):
-        tensors = []
-        for _ in range(random.randint(1, 5)):
-            ndim = random.randint(0, 4)
-            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
-            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
-            tensors.append(make_tensor(shape))
-
-        total_size = sum(map(torch.Tensor.numel, tensors))
-        if total_size == 0:
-            continue
-        num_chunks = random.randint(1, min(100, sum(x.numel() for x in tensors)))
-        part_sizes = load_balance_peers(total_size, [None] * num_chunks)
-        chunks = split_into_parts(tensors, part_sizes)
-        assert len(chunks) == num_chunks
-        shapes = [tensor.shape for tensor in tensors]
-        restored = restore_from_parts(chunks, shapes)
-        assert len(restored) == len(tensors)
-        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
-        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))
-
-
 def get_cost(vector_size, partitions, throughputs):
     return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
                for i in range(len(partitions)))

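Since test_allreduce_compression above exercises per-tensor compression end to end, a stripped-down construction sketch may help readers of this PR (construction only; a real averaging round needs at least target_group_size peers calling step(), exactly as in the test):

import torch
import hivemind
from hivemind.proto.runtime_pb2 import CompressionType

dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
tensors = [torch.randn(1000), torch.randn(1000)]

# one compression type per local tensor: the first is exchanged as float16, the second as uint8
averager = hivemind.DecentralizedAverager(
    tensors, dht=dht, prefix='mygroup', target_group_size=2, start=True,
    compression_type=(CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT))
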
+ 37 - 1
tests/test_util_modules.py

@@ -11,6 +11,7 @@ from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 import hivemind
 from hivemind.utils import MSGPackSerializer
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
 from hivemind.utils.mpfuture import FutureStateError
 
 
@@ -142,6 +143,7 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     for compression_type in CompressionType.values():
         assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
 
+
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_channel_cache():
@@ -256,7 +258,7 @@ def test_split_parts():
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
         with pytest.raises(RuntimeError):
             deserialize_torch_tensor(combined)
-            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceProtocol
+            # note: we rely on this being RuntimeError in hivemind.client.averaging.allreduce.AllReduceRunner
 
 
 def test_generic_data_classes():
@@ -271,3 +273,37 @@ def test_generic_data_classes():
     sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
     sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
     assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])
+
+
+@pytest.mark.asyncio
+async def test_asyncio_utils():
+    res = [i async for i, item in aenumerate(aiter('a', 'b', 'c'))]
+    assert res == list(range(len(res)))
+
+    num_steps = 0
+    async for elem in amap_in_executor(lambda x: x ** 2, aiter(*range(100)), max_prefetch=5):
+        assert elem == num_steps ** 2
+        num_steps += 1
+    assert num_steps == 100
+
+    ours = [elem async for elem in amap_in_executor(max, aiter(*range(7)), aiter(*range(-50, 50, 10)), max_prefetch=1)]
+    ref = list(map(max, range(7), range(-50, 50, 10)))
+    assert ours == ref
+
+    ours = [row async for row in azip(aiter('a', 'b', 'c'), aiter(1, 2, 3))]
+    ref = list(zip(['a', 'b', 'c'], [1, 2, 3]))
+    assert ours == ref
+
+    async def _aiterate():
+        yield 'foo'
+        yield 'bar'
+        yield 'baz'
+
+    iterator = _aiterate()
+    assert (await anext(iterator)) == 'foo'
+    tail = [item async for item in iterator]
+    assert tail == ['bar', 'baz']
+    with pytest.raises(StopAsyncIteration):
+        await anext(iterator)
+
+    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ['foo', 'bar', 'baz'] + list(range(5))