
Add per-tensor compression, make all-reduce faster and more flexible (#272)

* extract core components of all-reduce as TensorPartContainer / TensorPartReducer (see the usage sketch after this list)
* ensure that compression happens asynchronously in background threads (per tensor part)
* update AllReduceRunner for new partitioning 
* per-tensor compression in TensorPartContainer
* minimize memory allocation (e.g. use iterator in update_tensors)
* update DecentralizedAverager to use the new AllReduceRunner
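
The sketch below is not part of the commit; it shows how the two extracted components are meant to fit together, mirroring the local-averaging branch of AllReduceRunner._communicate_with_peer further down in this diff. A degenerate single-sender group is assumed, so accumulate_part resolves as soon as the local contribution arrives and every output delta is zero; import paths follow the files touched by this PR.

```python
# Hypothetical usage sketch based on the APIs added in hivemind/client/averaging/partition.py below.
# One peer owns the whole vector (fraction 1.0) and is its only sender, so the "average" of every
# part equals the original part and all output deltas are zero.
import asyncio
import torch
from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer


async def single_sender_allreduce(tensors):
    container = TensorPartContainer(tensors, peer_fractions=(1.0,), part_size_bytes=2 ** 16)
    local_parts = container.get_raw_input_parts(peer_index=0)
    reducer = TensorPartReducer(tuple(part.shape for part in local_parts), num_senders=1)

    for part_index, tensor_part in enumerate(local_parts):
        # with a single sender, accumulate_part resolves right after our own contribution
        averaged_part = await reducer.accumulate_part(0, part_index, tensor_part)
        # AllReduceRunner registers deltas (averaged - original); here the delta is exactly zero
        container.register_processed_part(0, part_index, averaged_part - tensor_part)

    deltas = [delta async for delta in container.iterate_output_tensors()]
    reducer.finalize()
    container.finalize()
    return deltas


deltas = asyncio.run(single_sender_allreduce([torch.randn(100, 10), torch.rand(50)]))
assert all(torch.allclose(delta, torch.zeros_like(delta)) for delta in deltas)
```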

Tests:
* test that partitioning recovers the original tensors (see the sketch after this list)
* test partitioning edge cases (e.g. empty tensors)
* test that partitioning is indeed asynchronous
* test new all-reduce protocol in separate file
* test asyncio utility functions
* benchmark performance under limited bandwidth (see PR discussion)
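
A hedged sketch of the first check above (the real tests live in the repository's test suite, not in this diff): serialized parts produced by iterate_input_parts_for are deserialized and registered back as outputs, which must reassemble into the original tensors even when each tensor uses its own compression. The FLOAT16 tolerance is an assumption, not a value taken from the PR.

```python
# Hypothetical round-trip test; APIs and import paths follow partition.py as added in this diff.
import asyncio
import torch
from hivemind.client.averaging.partition import TensorPartContainer
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import deserialize_torch_tensor


async def check_partitioning_roundtrip():
    tensors = [torch.randn(30, 40), torch.rand(1000), torch.randn(3, 3, 3)]
    container = TensorPartContainer(
        tensors, peer_fractions=(0.2, 0.5, 0.3), part_size_bytes=2 ** 10,
        compression_type=[CompressionType.NONE, CompressionType.FLOAT16, CompressionType.NONE])

    # echo each peer's serialized input parts back as if they had already been averaged
    for peer_index in range(container.group_size):
        part_index = 0
        async for serialized_part in container.iterate_input_parts_for(peer_index):
            container.register_processed_part(peer_index, part_index, deserialize_torch_tensor(serialized_part))
            part_index += 1

    restored = [tensor async for tensor in container.iterate_output_tensors()]
    for original, recovered in zip(tensors, restored):
        assert original.shape == recovered.shape
        # FLOAT16 is lossy, hence the loose tolerance (assumed, not taken from the PR)
        assert torch.allclose(original, recovered, rtol=0, atol=1e-2)


asyncio.run(check_partitioning_roundtrip())
```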


Co-authored-by: mponty <heapnhash@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Michael Diskin <yhn112@users.noreply.github.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic, 4 years ago
commit 0a0e290ea3

+ 1 - 1
docs/modules/client.rst

@@ -25,4 +25,4 @@
 .. autoclass:: DecentralizedAverager
    :members:
    :member-order: bysource
-   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part
+   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part, register_allreduce_group

+ 55 - 50
hivemind/client/averaging/__init__.py

@@ -20,6 +20,7 @@ import torch
 import numpy as np

 from hivemind.dht import DHT, DHTID
+from hivemind.client.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -34,9 +35,8 @@ from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescripto

 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
-DataForGather = Any
+GatheredData = Any
 logger = get_logger(__name__)
-DEFAULT_CHUNK_SIZE_BYTES = 2 ** 16


 class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
@@ -61,7 +61,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
-    :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
+    :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param throughput: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
           If throughput == 0, averager will rely on its groupmates to do all the averaging.
@@ -94,8 +94,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin

     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
-                 allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
+                 averaging_expiration: float = 15, request_timeout: float = 3, averaging_alpha: float = 1.0,
+                 part_size_bytes: int = DEFAULT_PART_SIZE_BYTES, allreduce_timeout: Optional[float] = None,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0,
                  auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
@@ -135,7 +135,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
             min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
-        self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes,
+        self.allreduce_kwargs = dict(compression_type=compression_type, part_size_bytes=part_size_bytes,
                                      min_vector_size=min_vector_size)
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -251,8 +251,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if self._parent_pid != os.getpid() or self.is_alive():
             self.shutdown()

-    def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None,
-             allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
+    def step(self, gather: Optional[GatheredData] = None, weight: Optional[float] = None,
+             timeout: Optional[float] = None, allow_retries: bool = True, wait: bool = True
+             ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
         """
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
 
@@ -265,10 +266,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
-        if self.mode == AveragingMode.AUX and weight != 1:
+        if self.mode == AveragingMode.AUX and weight is not None:
             logger.warning("Averager is running in auxiliary mode, weight is unused.")
-        else:
-            assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+        if weight is None:
+            weight = float(self.mode != AveragingMode.AUX)
+        assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"

         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
@@ -278,9 +280,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin

     async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
                     allow_retries: bool, timeout: Optional[float]):
-        loop = asyncio.get_event_loop()
         start_time = get_dht_time()
-        group_id = None

         try:
             while not future.done():
@@ -291,16 +291,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                                                                        data_for_gather=data_for_gather)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
-                    group_id = group_info.group_id
-                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                    self._running_groups[group_id] = allreduce_runner
-                    self._pending_group_assembled.set()
-                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                    if self.mode != AveragingMode.AUX:
-                        await loop.run_in_executor(None, self.update_tensors, allreduce_runner)

-                    # averaging is finished, exit the loop
-                    future.set_result(allreduce_runner.gathered)
+                    future.set_result(await asyncio.wait_for(
+                        self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout))
+                    # averaging is finished, loop will now exit

                 except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
                         asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
@@ -311,10 +305,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     else:
                         logger.warning(f"Averager caught {repr(e)}, retrying")

-                finally:
-                    _ = self._running_groups.pop(group_id, None)
-                    self._pending_group_assembled.set()
-
         except BaseException as e:
             if not future.done():
                 future.set_exception(e)
@@ -324,35 +314,51 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
                                                   " Please report this to hivemind issues."))

-    async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
-        """ Use a group description found by Matchmaking to form AllreduceRunner """
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """ Run All-Reduce in a given group and update tensors in place, return gathered metadata """
         try:
             weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
-            # compute optimal part sizes from peer throughputs
             modes = tuple(map(AveragingMode, mode_ids))
-            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)]  # TODO: replace with proper load balancing
-            part_sizes = await asyncio.get_event_loop().run_in_executor(
+
+            # compute optimal part sizes from peer throughputs; TODO: replace with proper load balancing
+            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0
+                                    for thr, mode in zip(throughputs, modes)]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                 None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
-            async with self.get_tensors_async() as averaged_tensors:
-                return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
-                                       ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
-                                       weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs)
-        except Exception as e:
-            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}")

-    def update_tensors(self, allreduce_group: AllReduceRunner):
-        """
-        a private (extendable) method that applies changes from a finished allreduce to local tensors
-        """
-        assert allreduce_group.return_deltas and allreduce_group.future.done()
-        averaging_deltas = allreduce_group.future.result()
+            async with self.get_tensors_async() as local_tensors:
+                allreduce = AllReduceRunner(
+                    group_id=group_info.group_id, tensors=local_tensors, endpoint=self.endpoint,
+                    ordered_group_endpoints=group_info.endpoints, peer_fractions=peer_fractions, weights=weights,
+                    gathered=user_gathered, modes=modes, **kwargs)
 
 
-        with torch.no_grad(), self.get_tensors() as local_tensors:
-            assert len(local_tensors) == len(self._averaged_tensors)
-            for tensor, update in zip(local_tensors, averaging_deltas):
-                tensor.add_(update, alpha=self._averaging_alpha)
-        self.last_updated = get_dht_time()
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+
+                    # actually run all-reduce
+                    averaging_outputs = [output async for output in allreduce]
+
+                    if modes[group_info.endpoints.index(self.endpoint)] != AveragingMode.AUX:
+                        assert len(local_tensors) == len(self._averaged_tensors)
+                        for tensor, update in zip(local_tensors, averaging_outputs):
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self.last_updated = get_dht_time()
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    @contextlib.contextmanager
+    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
+        """ registers a given all-reduce runner to listen for incoming connections """
+        try:
+            self._running_groups[group_id] = allreduce
+            self._pending_group_assembled.set()
+            yield
+        finally:
+            self._running_groups.pop(group_id, None)
+            self._pending_group_assembled.set()

     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -418,11 +424,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """
         """
         if not self.allow_state_sharing:
         if not self.allow_state_sharing:
             return  # deny request and direct peer to the next prospective averager
             return  # deny request and direct peer to the next prospective averager
-        chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
         metadata, tensors = await self._get_current_state_from_host_process()

         for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
+            for part in split_for_streaming(serialize_torch_tensor(tensor)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None

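For reference, a hedged caller-side sketch of the renamed argument (chunk_size_bytes -> part_size_bytes) and the new weight=None default shown above; the DHT setup, prefix and group size are illustrative and not part of this diff.

```python
# Hypothetical caller-side sketch: only part_size_bytes, compression_type and the weight default
# reflect this diff; everything else (DHT setup, prefix, group size) is made up for illustration.
import torch
import hivemind
from hivemind.client.averaging import DecentralizedAverager
from hivemind.proto.runtime_pb2 import CompressionType

dht = hivemind.DHT(start=True)
averager = DecentralizedAverager(
    averaged_tensors=[torch.zeros(1024)], dht=dht, start=True,
    prefix="demo_run", target_group_size=4,
    part_size_bytes=2 ** 19,                     # was chunk_size_bytes before this change
    compression_type=CompressionType.FLOAT16)

# weight now defaults to None: regular peers contribute with weight 1.0, auxiliary peers with 0.0
gathered = averager.step(gather={"round": 0}, timeout=60)
```

With only one local peer this step would simply fail to find a group; the snippet illustrates the changed signature, not a working averaging run.
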
+ 159 - 206
hivemind/client/averaging/allreduce.py

@@ -1,14 +1,15 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any, Optional
+from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
 from enum import Enum

 import grpc
 import torch

-from hivemind.utils import Endpoint, get_logger, ChannelCache, anext
-from hivemind.utils import split_for_streaming, combine_from_streaming
+from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
+from hivemind.utils import Endpoint, get_logger, ChannelCache
+from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
+from hivemind.proto import averaging_pb2_grpc, averaging_pb2

 # flavour types
 GroupID = bytes
@@ -21,256 +22,208 @@ class AveragingMode(Enum):
     AUX = 2


-class AllReduceProtocol:
+class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     """
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers
     An internal class that runs butterfly AllReduce in a predefined group of averagers
 
 
+    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param group_id: unique identifier of this specific all-reduce run
+    :param tensors: local tensors that should be averaged with groupmates
     :param tensors: local tensors that should be averaged with groupmates
     :param endpoint: your endpoint, must be included in ordered_group_endpoints
     :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
-    :param part_sizes: for each peer, a number of vector elements that this peer is responsible for averaging
-    :param return_deltas: if True, returns the element-wise differences (averaged_tensors - original_tensors)
-           default (False) - return averaged_tensors by themselves
+    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
+      (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
+    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
+    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
+    :param gathered: additional user-defined data collected from this group
+    :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """
     """
 
 
-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False,
-                 modes: Optional[Sequence[AveragingMode]] = None):
+    def __init__(
+            self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
+            ordered_group_endpoints: Sequence[Endpoint], peer_fractions: Tuple[float, ...],
+            weights: Optional[Sequence[float]] = None, modes: Optional[Sequence[AveragingMode]] = None,
+            gathered: Optional[Dict[Endpoint, Any]] = None, **kwargs):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
-        self.group_id, self.endpoint = group_id, endpoint
-        self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
-        if modes is None:
-            modes = [AveragingMode.CLIENT if part_size == 0 else AveragingMode.NODE for part_size in part_sizes]
-        assert any(mode != AveragingMode.CLIENT for mode in modes), "Cannot run allreduce without reducers."
-        self.peer_modes = dict(zip(ordered_group_endpoints, modes))
-
-        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
-        self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
-        self.return_deltas = return_deltas
-
-        self.accumulator = torch.zeros_like(self.local_tensor_parts[self.endpoint])
-        self.denominator = 0.0  # number of peers added to accumulator or sum of their weights
-        self.accumulated_from: Set[Endpoint] = set()  # peers that we have accumulated our part from
-        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
-        self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
-        self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-
-        self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX])
-
-        if self.num_senders == 0:
-            self.future.set_result(None)
-        for endpoint, mode in self.peer_modes.items():
-            if mode == AveragingMode.CLIENT:
-                self.averaged_tensor_parts[endpoint] = torch.tensor([])
+        modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
+        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
+        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
+        assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
+        for mode, frac, weight in zip(modes, peer_fractions, weights):
+            assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
+            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
+
+        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+
+        self._future = asyncio.Future()
+
+        self.sender_endpoints, self.sender_weights = [], []
+        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
+            if mode != AveragingMode.AUX:
+                self.sender_endpoints.append(endpoint)
+                self.sender_weights.append(weight)
+
+        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
+        self.tensor_part_reducer = TensorPartReducer(tuple(part.shape for part in self.parts_for_local_averaging),
+                                                     len(self.sender_endpoints), self.sender_weights)
 

     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"

-        return self.future.__await__()
+    def __aiter__(self):
+        return self.run()
 
 
     def __contains__(self, endpoint: Endpoint):
     def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.local_tensor_parts
+        return endpoint in self.ordered_group_endpoints
 
 
     @property
     @property
     def group_size(self):
     def group_size(self):
         return len(self.ordered_group_endpoints)
         return len(self.ordered_group_endpoints)
 
 
-    async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
-        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
-        assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
-        assert source not in self.accumulated_from, "duplicate source, already received that part"
-        assert self.peer_modes[self.endpoint] != AveragingMode.CLIENT, f"{self.endpoint} is in AveragingMode.client mode"
-        assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
-
-        logger.debug(f"{self} - accumulating tensor part from {source}")
-        self.accumulator.add_(remote_part, alpha=weight)
-        self.denominator += weight
-        self.accumulated_from.add(source)
-
-        assert len(self.accumulated_from) <= self.num_senders
-        if len(self.accumulated_from) == self.num_senders:
-            average_result = self.accumulator.div_(self.denominator)
-            self.averaged_part.set_result(average_result)
-
-            if self.peer_modes[self.endpoint] == AveragingMode.AUX:
-                self.future.set_result(None)  # auxiliary mode has finished averaging
-            else:
-                self.register_averaged_part(self.endpoint, average_result)
-
-        return await self.averaged_part
-
-    def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "the provider of averaged part is not from my group"
-        assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
-        assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
-        assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
-        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending"
-        logger.debug(f"{self} - receiving averaged tensor part from {source}")
-        self.averaged_tensor_parts[source] = averaged_part
-        if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
-            ordered_averaged_parts = [self.averaged_tensor_parts[endpoint] for endpoint in self.ordered_group_endpoints]
-            outputs = restore_from_parts(ordered_averaged_parts, self.tensor_shapes)
-
-            if self.return_deltas:
-                local_parts = [self.local_tensor_parts[peer] for peer in self.ordered_group_endpoints]
-                with torch.no_grad():
-                    original_tensors = restore_from_parts(local_parts, self.tensor_shapes)
-                    for averaged_tensor, original_tensor in zip(outputs, original_tensors):
-                        averaged_tensor -= original_tensor
-
-            self.future.set_result(outputs)
-
-    def cancel(self) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - cancelled")
-            self.future.cancel()
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.future}")
-            return False
-
-    def set_exception(self, exception: Exception) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - {exception}")
-            self.future.set_exception(exception)
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.future}")
-            return False
-
-
-class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragingServicer):
-    """
-    A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
-    """
-
-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
-                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
-                 gathered: Dict[Endpoint, Any], return_deltas: bool = False, **kwargs):
-        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
-                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas, **kwargs)
-        self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
-        self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
-
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)

-    async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
-        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
-        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
-        if peer_endpoint == self.endpoint:
-            return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
-        serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-        chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)
-
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
-                                                       endpoint=self.endpoint, tensor_part=next(chunks)))
-        for chunk in chunks:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
-        await stream.done_writing()
-
-        outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
-        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
-        if code != averaging_pb2.AVERAGED_PART:
-            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
-                                     f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
-                                     f" allreduce failed")
-
+    async def run(self) -> AsyncIterator[torch.Tensor]:
+        """ Run all-reduce, return differences between averaged and original tensors as they are computed """
+        pending_tasks = set()
         try:
-            averaged_part = local_part + deserialize_torch_tensor(combine_from_streaming(
-                [message.tensor_part for message in outputs]))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
+            if len(self.sender_endpoints) == 0:
+                logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
+                self.finalize()

-        self.register_averaged_part(peer_endpoint, averaged_part)
-        return averaged_part
+            elif self.endpoint in self.sender_endpoints:
+                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
+                    if parts != 0:
+                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))

-    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+                async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
+                    yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+                self.finalize()
+
+            else:  # auxiliary peer
+                await self.tensor_part_reducer.finished.wait()
+                self.finalize()

-    async def run(self) -> Sequence[torch.Tensor]:
-        """
-        send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
-        """
-        try:
-            if self.peer_modes[self.endpoint] != AveragingMode.AUX:
-                await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
-                                            for i, peer in enumerate(self.ordered_group_endpoints)
-                                            if self.peer_modes[peer] != AveragingMode.CLIENT))
-            return await self
         except BaseException as e:
         except BaseException as e:
+            self.finalize(exception=e)
+            for task in pending_tasks:
+                task.cancel()
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_endpoint, part_size in zip(self.ordered_group_endpoints, self.part_sizes):
-                if peer_endpoint != self.endpoint and part_size > 0:
+            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
+                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
                     asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
             raise

-    async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
-                                        ) -> Iterable[runtime_pb2.Tensor]:
-        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
-        try:
-            tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")
+    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
+        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
+        if peer_endpoint == self.endpoint:
+            sender_index = self.sender_endpoints.index(peer_endpoint)
+            for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
+                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
 
-        averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
-        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type, allow_inplace=False)
-        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
-        return stream_chunks
+        else:
+            loop = asyncio.get_event_loop()
+            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
+
+            try:
+                code = None
+                async for part_index, msg in aenumerate(stream):
+                    if code is None:
+                        code = msg.code
+                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+                await write_task
+
+                if code != averaging_pb2.AVERAGED_PART:
+                    raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                                             f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                                             f", allreduce failed")
+            finally:
+                if not write_task.done():
+                    write_task.cancel()
+
+    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+        first_part = await anext(parts_aiter)
+        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
+                                                       group_id=self.group_id, endpoint=self.endpoint,
+                                                       tensor_part=first_part))
+        async for part in parts_aiter:
+            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
+
+        await stream.done_writing()

     async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
                                  ) -> AsyncIterator[averaging_pb2.AveragingData]:
-        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the delta"""
+        """ a peer sends us a part of his tensor; we should average it with other peers and return the difference """
         request: averaging_pb2.AveragingData = await anext(stream)
-
-        if request.group_id != self.group_id:
-            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        reason_to_reject = self._check_reasons_to_reject(request)
+        if reason_to_reject:
+            yield reason_to_reject
+            return

         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                tensor_chunks = (request.tensor_part, *[msg.tensor_part async for msg in stream])
-                averaged_chunks = iter(await self.accumulate_part_streaming(request.endpoint, tensor_chunks))
-                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=next(averaged_chunks))
-                for averaged_chunk in averaged_chunks:
-                    yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
+                sender_index = self.sender_endpoints.index(request.endpoint)
+                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
+                    yield msg

             except Exception as e:
-                self.set_exception(e)
+                self.finalize(exception=e)
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
             logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
-            self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

+    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+        if request.group_id != self.group_id:
+            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        elif self._future.cancelled():
+            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        elif self._future.done():
+            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+
+    async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
+        loop = asyncio.get_event_loop()
+        async for part_index, (tensor_part, part_compression) in aenumerate(
+                amap_in_executor(lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression), stream,
+                                 max_prefetch=self.tensor_part_container.prefetch)):
+            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+
+            serialized_delta = await loop.run_in_executor(
+                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression))
+            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

-def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]:
-    """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
-    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
-    return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)
-
-
-def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
-    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
-    flat_tensor = torch.cat(tuple(chunks))
-    result_sizes = tuple(map(torch.Size.numel, shapes))
-    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
-    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
-
+    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
+        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
+        await stream.done_writing()

-class AllreduceException(Exception):
-    """ A special exception that is raised when allreduce can't continue normally (e.g. disbanded/bad request/etc) """
+    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
+        assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
+        if not self._future.done():
+            if cancel:
+                logger.debug(f"{self} - cancelled")
+                self._future.cancel()
+            elif exception:
+                logger.debug(f"{self} - caught {exception}")
+                self._future.set_exception(exception)
+            else:
+                logger.debug(f"{self} - finished")
+                self._future.set_result(None)
+            self.tensor_part_container.finalize()
+            self.tensor_part_reducer.finalize()
+            return True
+        else:
+            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
+            return False

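AllReduceRunner is now consumed as an async iterator that yields per-tensor deltas as they are reassembled (see _run_allreduce in the averager diff above). The sketch below is hypothetical: a one-peer group averages with itself, so it runs without any gRPC traffic; the endpoint string and group id are made up.

```python
# Hypothetical sketch of the new streaming interface; with a single peer no rpc_aggregate_part
# calls are issued and every yielded delta (averaged - original) is zero.
import asyncio
import torch
from hivemind.client.averaging.allreduce import AllReduceRunner


async def demo():
    tensors = [torch.randn(16, 16), torch.rand(100)]
    allreduce = AllReduceRunner(
        group_id=b"demo-group", tensors=tensors, endpoint="localhost:1337",
        ordered_group_endpoints=["localhost:1337"], peer_fractions=(1.0,))

    # deltas are yielded tensor by tensor as soon as their parts are reassembled
    deltas = [delta async for delta in allreduce]
    assert all(torch.allclose(delta, torch.zeros_like(delta)) for delta in deltas)


asyncio.run(demo())
```
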
+ 1 - 0
hivemind/client/averaging/load_balancing.py

@@ -28,6 +28,7 @@ def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_
         assert not all(throughput == 0 for throughput in throughputs), "Must have at least one nonzero throughput"
         scores = np.asarray([1.0 if throughput is None else 0.0 for throughput in throughputs])

+    #TODO(jheuristic) we no longer need hagenbach-bishoff with new AllReduceRunner
     return tuple(hagenbach_bishoff(vector_size, scores))



+ 224 - 0
hivemind/client/averaging/partition.py

@@ -0,0 +1,224 @@
+"""
+Auxiliary data structures for AllReduceRunner
+"""
+import asyncio
+from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
+from collections import deque
+
+import torch
+import numpy as np
+
+from hivemind.proto.runtime_pb2 import CompressionType, Tensor
+from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
+from hivemind.utils.asyncio import amap_in_executor
+
+
+T = TypeVar('T')
+DEFAULT_PART_SIZE_BYTES = 2 ** 20
+
+
+class TensorPartContainer:
+    """
+    Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
+    The class is designed to avoid excessive memory allocation and run all heavy computation in background
+    :param tensors: local tensors to be split and aggregated
+    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
+    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
+    :param prefetch: when compressing, pre-compute this many compressed tensors in background
+    """
+
+    def __init__(self, tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float],
+                 compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
+                 part_size_bytes: int = 2 ** 20, prefetch: int = 1):
+        if not isinstance(compression_type, Sequence):
+            compression_type = [compression_type] * len(tensors)
+        assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
+        self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
+        self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
+        self.total_size = sum(tensor.numel() for tensor in tensors)
+        self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
+        self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
+        self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
+        self._outputs_consumed = False
+        self.finished = asyncio.Event()
+        self.num_parts_by_tensor = []
+
+        # split tensor parts in proportion to target_size_by_peer
+        current_length = 0
+        current_peer_index = 0
+        pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
+        pivots[-1] = self.total_size
+
+        for tensor, tensor_compression in zip(self.local_tensors, compression_type):
+            part_size_values = int(part_size_bytes / get_nbytes_per_value(tensor.dtype, tensor_compression))
+            tensor_parts = tensor.detach().view(-1).split(part_size_values)
+            self.num_parts_by_tensor.append(len(tensor_parts))
+            for part in tensor_parts:
+                if current_length + len(part) > pivots[current_peer_index]:
+                    # switch to next peer; if a part lands between parts of two or
+                    # more peers, assign that part to the peer with highest intersection
+                    prev_peer_index = current_peer_index
+                    peer_intersections = [pivots[current_peer_index] - current_length]
+                    while current_length + len(part) > pivots[current_peer_index]:
+                        current_peer_index += 1
+                        current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
+                        peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
+                    assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
+                    self._input_parts_by_peer[assigned_peer_index].append((part, tensor_compression))
+                else:
+                    self._input_parts_by_peer[current_peer_index].append((part, tensor_compression))
+                current_length += len(part)
+
+        assert current_length == self.total_size
+        self.num_parts_by_peer = tuple(len(parts) for parts in self._input_parts_by_peer)
+
+    @torch.no_grad()
+    def get_raw_input_parts(self, peer_index: int) -> Tuple[torch.Tensor, ...]:
+        """ get non-serialized tensor parts for a peer at a given index """
+        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
+        self._inputs_consumed_by_peer[peer_index] = True
+        input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
+        self._input_parts_by_peer[peer_index].clear()
+        return input_parts
+
+    @torch.no_grad()
+    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
+        """ iterate serialized tensor parts for a peer at a given index. Run serialization in background. """
+        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
+        self._inputs_consumed_by_peer[peer_index] = True
+
+        async def _aiterate_parts():
+            for _ in range(self.num_parts_by_peer[peer_index]):
+                yield self._input_parts_by_peer[peer_index].popleft()
+
+        async for serialized_part in amap_in_executor(lambda x_and_compr: serialize_torch_tensor(*x_and_compr),
+                                                      _aiterate_parts(), max_prefetch=self.prefetch):
+            yield serialized_part
+
+    def register_processed_part(self, peer_index: int, part_index: int, part: torch.Tensor):
+        """
+        register next-in-line part of results received from a given peer for use in iterate_output_tensors
+        depending on the algorithm, processed part is an average, difference from average or another aggregation
+        """
+        if part_index != self._outputs_registered_by_peer[peer_index]:
+            raise ValueError(f"Could not register part #{part_index} from peer #{peer_index}, "
+                             f" expected part index: {self._outputs_registered_by_peer[peer_index]}")
+        self._output_parts_by_peer[peer_index].append(part)
+        self._outputs_registered_by_peer[peer_index] += 1
+        self._output_part_available[peer_index].set()
+
+    async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
+        """ iterate over the outputs of averaging (whether they are average, delta or other aggregation result) """
+        assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
+        self._outputs_consumed = True
+        peer_index = num_parts_processed = 0
+        for tensor_index in range(len(self.local_tensors)):
+            tensor_parts = []
+            while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
+                if num_parts_processed >= self.num_parts_by_peer[peer_index]:
+                    num_parts_processed = 0
+                    peer_index += 1
+                    continue
+                if not self._output_parts_by_peer[peer_index]:
+                    self._output_part_available[peer_index].clear()
+                    await self._output_part_available[peer_index].wait()
+                    if self.finished.is_set():
+                        raise AllreduceException("All-reduce was terminated during iteration.")
+
+                tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
+                num_parts_processed += 1
+            tensor = torch.cat(tensor_parts)
+            del tensor_parts
+            yield tensor.reshape(self.local_tensors[tensor_index].shape)
+
+    def __del__(self):
+        self.finalize()
+
+    def finalize(self):
+        """ terminate all iterators, delete intermediate data """
+        if not self.finished.is_set():
+            for peer_index in range(self.group_size):
+                self._inputs_consumed_by_peer[peer_index] = True
+                self._input_parts_by_peer[peer_index].clear()
+                self._output_parts_by_peer[peer_index].clear()
+                self._output_part_available[peer_index].set()
+            self._outputs_consumed = True
+            self.finished.set()
+
+
+class TensorPartReducer:
+    """
+    Auxiliary data structure responsible for running asynchronous all-reduce
+    :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
+    :param num_senders: total number of peers in a given all-reduce group that will send tensor parts
+    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
+    :note: even if local peer is not sending data, local parts will be used for shape information
+    """
+
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int,
+                 weights: Optional[Sequence[float]] = None):
+        self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
+        self.weights = tuple(weights or (1 for _ in range(num_senders)))
+        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
+        assert all(isinstance(weight, (int, float)) for weight in self.weights)
+        self.current_part_index = -1  # index in local_parts of the part that should be loaded next
+        self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
+        self.accumulator = None  # this will contain the sum of current tensor part from group peers
+        self.denominator = 0.0  # total weight accumulated from all peers for current part
+        self.current_part_future = asyncio.Future()
+        self.finished = asyncio.Event()
+        self.reset_accumulators()
+
+    def reset_accumulators(self):
+        """ (re)create averaging buffers for the next part in line, prepopulate with local tensor part """
+        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
+        if self.current_part_index >= self.num_parts - 1:
+            self.finalize()
+            return
+
+        self.current_part_index += 1
+        self.current_part_accumulated_from = 0
+        self.current_part_future = asyncio.Future()
+        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
+        self.denominator = 0.0
+
+    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
+        assert 0 <= sender_index < self.num_senders, "invalid sender index"
+        assert 0 <= part_index < self.num_parts, "invalid part index"
+
+        while part_index > self.current_part_index:
+            # wait for previous parts to finish processing ...
+            await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
+            if self.finished.is_set():
+                raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
+        assert part_index == self.current_part_index
+
+        current_part_future = self.current_part_future
+
+        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
+        self.denominator += self.weights[sender_index]
+        self.current_part_accumulated_from += 1
+
+        assert self.current_part_accumulated_from <= self.num_senders
+        if self.current_part_accumulated_from == self.num_senders:
+            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            self.reset_accumulators()
+        return await current_part_future
+
+    def finalize(self):
+        if not self.finished.is_set():
+            if hasattr(self, 'current_part_future'):
+                self.current_part_future.cancel()
+                del self.accumulator
+            self.finished.set()
+
+    def __del__(self):
+        self.finalize()
+
+
+class AllreduceException(Exception):
+    """ A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error) """

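Note (illustrative, not part of the committed diff): below is a minimal single-peer sketch of how TensorPartContainer and TensorPartReducer are meant to interact, assuming only the API visible above and in tests/test_allreduce.py (the classes live in hivemind.client.averaging.partition). Serialization and networking are skipped, the local peer acts as the only sender, and the helper function name is just for the sketch.

import asyncio
import torch

from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer


async def local_allreduce_sketch():
    tensors = [torch.randn(1024), torch.rand(16, 16)]
    # a single-peer "group": one peer owns the full flattened vector
    container = TensorPartContainer(tensors, (1.0,), part_size_bytes=2 ** 16)

    # the local peer keeps its own parts raw; remote peers would instead receive
    # serialized parts via iterate_input_parts_for()
    local_parts = container.get_raw_input_parts(0)
    reducer = TensorPartReducer(tuple(part.shape for part in local_parts), num_senders=1)

    for part_index, part in enumerate(local_parts):
        # with a single sender, the "average" of each part is the part itself
        averaged_part = await reducer.accumulate_part(sender_index=0, part_index=part_index, tensor_part=part)
        container.register_processed_part(0, part_index, averaged_part)

    averaged = [tensor async for tensor in container.iterate_output_tensors()]
    assert all(torch.allclose(out, ref) for out, ref in zip(averaged, tensors))


asyncio.run(local_allreduce_sketch())

In the full protocol, the serialized parts produced by iterate_input_parts_for would travel over gRPC to the peers responsible for them, and the averaged parts coming back would be fed through register_processed_part in the same way.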
+ 1 - 1
hivemind/proto/averaging.proto

@@ -43,7 +43,7 @@ message MessageFromLeader {
   bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
   string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
   repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
-  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endoints
+  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endpoints
 }
 
 message AveragingData {

+ 49 - 1
hivemind/utils/asyncio.py

@@ -1,7 +1,14 @@
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable
+from concurrent.futures import ThreadPoolExecutor
+from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable, Tuple, Optional, Callable
 import asyncio
+
 import uvloop
+
+from hivemind.utils.logging import get_logger
+
+
 T = TypeVar('T')
+logger = get_logger(__name__)
 
 
 def switch_to_uvloop() -> asyncio.AbstractEventLoop:
@@ -27,6 +34,16 @@ async def aiter(*args: T) -> AsyncIterator[T]:
         yield arg
 
 
+async def azip(*iterables: AsyncIterable[T]) -> AsyncIterator[Tuple[T, ...]]:
+    """ equivalent of zip for asynchronous iterables """
+    iterators = [iterable.__aiter__() for iterable in iterables]
+    while True:
+        try:
+            yield tuple(await asyncio.gather(*(itr.__anext__() for itr in iterators)))
+        except StopAsyncIteration:
+            break
+
+
 async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
     """ equivalent to chain(iter1, iter2, ...) for asynchronous iterators. """
     for aiter in async_iters:
@@ -34,6 +51,14 @@ async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
             yield elem
 
 
+async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]]:
+    """ equivalent to enumerate(iter) for asynchronous iterators. """
+    index = 0
+    async for elem in aiterable:
+        yield index, elem
+        index += 1
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
@@ -42,3 +67,26 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
         return True
     except BaseException:
         return False
+
+
+async def amap_in_executor(func: Callable[..., T], *iterables: AsyncIterable, max_prefetch: Optional[int] = None,
+                           executor: Optional[ThreadPoolExecutor] = None) -> AsyncIterator[T]:
+    """ iterate from an async iterable in a background thread, yield results to async iterable """
+    loop = asyncio.get_event_loop()
+    queue = asyncio.Queue(max_prefetch)
+
+    async def _put_items():
+        async for args in azip(*iterables):
+            await queue.put(loop.run_in_executor(executor, func, *args))
+        await queue.put(None)
+
+    task = asyncio.create_task(_put_items())
+    try:
+        future = await queue.get()
+        while future is not None:
+            yield await future
+            future = await queue.get()
+        await task
+    finally:
+        if not task.done():
+            task.cancel()
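Note (illustrative sketch, not from the diff): amap_in_executor is what lets per-part compression run off the event loop. A minimal standalone usage, assuming only the helpers added above, looks like this; serialize_torch_tensor runs in the loop's default thread pool while the coroutine merely awaits the next prefetched result.

import asyncio
import torch

from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.asyncio import aiter, amap_in_executor
from hivemind.utils.compression import serialize_torch_tensor


async def serialize_in_background():
    parts = [torch.randn(100_000) for _ in range(8)]
    # compression happens in a background thread; the event loop stays free to serve peers
    async for serialized in amap_in_executor(
            lambda part: serialize_torch_tensor(part, CompressionType.FLOAT16),
            aiter(*parts), max_prefetch=2):
        print(f"serialized part: {len(serialized.buffer)} bytes")


asyncio.run(serialize_in_background())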

+ 12 - 0
hivemind/utils/compression.py

@@ -188,3 +188,15 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
 
 
     tensor.requires_grad_(serialized_tensor.requires_grad)
     return tensor
+
+
+def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
+    """ returns the number of bytes per value for a given tensor (excluding metadata) """
+    if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
+        return 1
+    elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
+        return 2
+    elif compression == CompressionType.NONE:
+        return torch.finfo(dtype).bits // 8
+    else:
+        raise NotImplementedError(f"Unknown compression type: {CompressionType.Name(compression)}")
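Note (illustrative sketch): combined with Tensor.numel(), the new helper gives a quick estimate of a tensor's wire size under each compression scheme (metadata excluded), which is handy when choosing a part_size_bytes budget. The numbers below assume a float32 tensor.

import torch

from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import get_nbytes_per_value

tensor = torch.randn(1_000_000)  # float32
for compression in (CompressionType.NONE, CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT):
    nbytes = tensor.numel() * get_nbytes_per_value(tensor.dtype, compression)
    print(f"{CompressionType.Name(compression)}: ~{nbytes / 2 ** 20:.1f} MiB on the wire")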

+ 5 - 1
hivemind/utils/grpc.py

@@ -158,7 +158,11 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
 
-def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int) -> Iterator[runtime_pb2.Tensor]:
+STREAMING_CHUNK_SIZE_BYTES = 2 ** 16
+
+
+def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
+                        ) -> Iterator[runtime_pb2.Tensor]:
     """ Split serialized_tensor into multiple chunks for gRPC streaming """
     buffer = memoryview(serialized_tensor.buffer)
     num_chunks = len(range(0, len(buffer), chunk_size_bytes))
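Note (illustrative sketch): with the new default, callers can omit chunk_size_bytes entirely. combine_from_streaming is assumed here to be the pre-existing counterpart in hivemind.utils.grpc for reassembling the chunks; it is not shown in this diff.

import torch

from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.grpc import split_for_streaming, combine_from_streaming

tensor = torch.randn(4096, 128)
serialized = serialize_torch_tensor(tensor)  # CompressionType.NONE by default
chunks = list(split_for_streaming(serialized))  # chunk_size_bytes now defaults to 2 ** 16
restored = deserialize_torch_tensor(combine_from_streaming(chunks))
assert torch.allclose(restored, tensor)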

+ 217 - 0
tests/test_allreduce.py

@@ -0,0 +1,217 @@
+import asyncio
+import random
+import time
+from typing import Sequence
+
+import pytest
+import torch
+import grpc
+
+from hivemind import aenumerate, Endpoint
+from hivemind.client.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer
+from hivemind.utils import deserialize_torch_tensor, ChannelCache
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.proto import averaging_pb2_grpc
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning():
+    all_tensors = [
+        torch.randn(30_000, 128), torch.rand(128), torch.ones(1, 1, 1, 1, 1, 1, 8),
+        torch.ones(1, 0), torch.zeros(0), torch.zeros([]), torch.randn(65536),
+        torch.rand(512, 2048), torch.randn(1024, 1024).add(-9), torch.zeros(1020), torch.randn(4096)
+    ]
+
+    # note: this test does _not_ use parameterization to reuse sampled tensors
+    for num_tensors in 1, 3, 5:
+        for part_size_bytes in 31337, 2 ** 20, 10 ** 10:
+            for weights in [(1, 1), (0.333, 0.1667, 0.5003), (1.0, 0.0), [0.0, 0.4, 0.6, 0.0]]:
+                tensors = random.choices(all_tensors, k=num_tensors)
+                partition = TensorPartContainer(tensors, weights, part_size_bytes=part_size_bytes)
+
+                async def write_tensors():
+                    for peer_index in range(partition.group_size):
+                        async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+                            output_tensor = torch.sin(deserialize_torch_tensor(part))
+                            partition.register_processed_part(peer_index, part_index, output_tensor)
+
+                task = asyncio.create_task(write_tensors())
+                tensor_index = 0
+                async for output_tensor in partition.iterate_output_tensors():
+                    assert torch.allclose(output_tensor, torch.sin(tensors[tensor_index]))
+                    tensor_index += 1
+                assert tensor_index == len(tensors)
+                await task
+
+
+@pytest.mark.parametrize("tensors", [[torch.zeros(0)], [torch.zeros(0), torch.zeros(0), torch.zeros(1)],
+                                     [torch.zeros(0), torch.zeros(999), torch.zeros(0), torch.zeros(0)]])
+@pytest.mark.parametrize("peer_fractions", [(0.33, 0.44, 0.23), (0.5, 0.5), (0.1, 0.0, 0.9), (1.0,), (0.1,) * 9])
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning_edge_cases(tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float]):
+    partition = TensorPartContainer(tensors, peer_fractions, part_size_bytes=16)
+    for peer_index in range(len(peer_fractions)):
+        async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+            partition.register_processed_part(peer_index, part_index, deserialize_torch_tensor(part))
+
+    tensor_index = 0
+    async for output_tensor in partition.iterate_output_tensors():
+        assert torch.allclose(output_tensor, tensors[tensor_index])
+        tensor_index += 1
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning_asynchronous():
+    """ ensure that tensor partitioning does not interfere with asynchronous code """
+    tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096),
+               torch.randn(4096, 1024), torch.randn(30_000, 1024)]
+    peer_fractions = [0.4, 0.3, 0.2, 0.1]
+
+    partition = TensorPartContainer(tensors, peer_fractions, compression_type=CompressionType.QUANTILE_8BIT)
+    read_started, read_finished = asyncio.Event(), asyncio.Event()
+
+    async def write_tensors():
+        for peer_index in range(partition.group_size):
+            async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+                partition.register_processed_part(peer_index, part_index, deserialize_torch_tensor(part))
+        assert read_started.is_set(), "partitioner should have started reading before it finished writing"
+
+    async def read_tensors():
+        async for _ in partition.iterate_output_tensors():
+            read_started.set()
+        read_finished.set()
+
+    async def wait_synchronously():
+        time_in_waiting = 0.0
+        while not read_finished.is_set():
+            await asyncio.sleep(0.01)
+            time_in_waiting += 0.01
+        return time_in_waiting
+
+    start_time = time.perf_counter()
+    *_, time_in_waiting = await asyncio.gather(write_tensors(), read_tensors(), wait_synchronously())
+    wall_time = time.perf_counter() - start_time
+    # check that event loop had enough time to respond to incoming requests; this is over 50% most of the time
+    # we set 33% threshold to ensure that the test will pass reliably. If we break prefetch, this drops to <10%
+    assert time_in_waiting > wall_time / 3, f"Event loop could only run {time_in_waiting / wall_time :.5f} of the time"
+
+
+@pytest.mark.parametrize("num_senders", [1, 2, 4, 10])
+@pytest.mark.parametrize("num_parts", [0, 1, 100])
+@pytest.mark.parametrize("synchronize_prob", [1.0, 0.1, 0.0])
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
+    tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
+    reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+
+    local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)]
+                               for j in range(num_senders)]
+
+    async def send_tensors(sender_index: int):
+        local_tensors = local_tensors_by_sender[sender_index]
+        averaged_parts = []
+        pending_tasks = []
+
+        for part_index in range(num_parts):
+            pending_tasks.append(asyncio.create_task(
+                reducer.accumulate_part(sender_index, part_index, local_tensors[part_index])))
+
+            if random.random() < synchronize_prob or part_index == num_parts - 1:
+                averaged_parts.extend(await asyncio.gather(*pending_tasks))
+                pending_tasks = []
+        return averaged_parts
+
+    averaged_tensors_by_peer = await asyncio.gather(*map(send_tensors, range(num_senders)))
+
+    reference = [sum(local_tensors_by_sender[sender_index][part_index]
+                     for sender_index in range(num_senders)) / num_senders
+                 for part_index in range(num_parts)]
+
+    for averaged_tensors in averaged_tensors_by_peer:
+        assert len(averaged_tensors) == len(reference)
+        for averaging_result, reference_tensor in zip(averaged_tensors, reference):
+            assert torch.allclose(averaging_result, reference_tensor, rtol=1e-3, atol=1e-5)
+
+
+class AllreduceRunnerForTesting(AllReduceRunner):
+    """ a version of AllReduceRunner that was monkey-patched to accept custom endpoint names """
+    def __init__(self, *args, peer_endpoints, **kwargs):
+        self.__peer_endpoints = peer_endpoints
+        super().__init__(*args, **kwargs)
+
+    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
+        return ChannelCache.get_stub(
+            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+
+
+NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
+
+
+@pytest.mark.parametrize("peer_modes, averaging_weights, peer_fractions", [
+    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1)),
+    ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1)),
+    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0)),
+    ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0)),
+    ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4)),
+    ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1)),
+    ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0)),
+    ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4)),
+])
+@pytest.mark.parametrize("part_size_bytes", [2 ** 20, 256, 19],)
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
+    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
+
+    peers = "alice", "bob", "carol", "colab"
+
+    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
+                       for i, peer in enumerate(peers)}
+
+    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
+
+    servers = []
+    allreduce_protocols = []
+    peer_endpoints = {}
+
+    for peer in peers:
+        server = grpc.aio.server()
+        allreduce_protocol = AllreduceRunnerForTesting(
+            group_id=group_id, endpoint=peer, tensors=[x.clone() for x in tensors_by_peer[peer]],
+            ordered_group_endpoints=peers, peer_fractions=peer_fractions, modes=peer_modes,
+            weights=averaging_weights, peer_endpoints=peer_endpoints, part_size_bytes=part_size_bytes
+        )
+        averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(allreduce_protocol, server)
+        peer_endpoints[peer] = f"127.0.0.1:{server.add_insecure_port('127.0.0.1:*')}"
+        allreduce_protocols.append(allreduce_protocol)
+        servers.append(server)
+        await server.start()
+
+    async def _run_allreduce_inplace(allreduce: AllReduceRunner):
+        async for tensor_index, tensor_delta in aenumerate(allreduce):
+            allreduce.tensor_part_container.local_tensors[tensor_index].add_(tensor_delta)
+
+    await asyncio.gather(*map(_run_allreduce_inplace, allreduce_protocols))
+
+    reference_tensors = [sum(tensors_by_peer[peer][i] * averaging_weights[peer_index]
+                             for peer_index, peer in enumerate(peers)) / sum(averaging_weights)
+                         for i in range(len(tensors_by_peer[peers[0]]))]
+
+    for peer_index, protocol in enumerate(allreduce_protocols):
+        assert protocol._future.done()
+        if protocol.modes[peer_index] != AveragingMode.AUX:
+            targets_for_peer = reference_tensors
+        else:
+            targets_for_peer = tensors_by_peer[peers[peer_index]]
+        output_tensors = protocol.tensor_part_container.local_tensors
+        assert len(output_tensors) == len(targets_for_peer)
+        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
+                   for our, ref in zip(output_tensors, targets_for_peer))
+
+    for server in servers:
+        await server.stop(grace=1)

+ 45 - 67
tests/test_averaging.py

@@ -1,14 +1,13 @@
-import asyncio
 import random
 
 import numpy as np
 import torch
 import pytest
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts, AveragingMode
+from hivemind.client.averaging.allreduce import AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager
-from hivemind.utils import Endpoint
+from hivemind.proto.runtime_pb2 import CompressionType
 
 
 @pytest.mark.forked
@@ -47,13 +46,13 @@ def _test_allreduce_once(n_clients, n_aux):
     n_peers = 4
     modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
     random.shuffle(modes)
-    
+
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
     peer_tensors = [tensors1, tensors2, tensors3, tensors4]
-    
+
     reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
                  if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]
 
@@ -130,6 +129,47 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht.shutdown()
 
 
+@pytest.mark.forked
+def test_allreduce_compression():
+    """ this test ensures that compression works correctly when multiple tensors have different compression types """
+    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+
+    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
+    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
+    results = {}
+
+    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
+
+    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        averager1 = hivemind.DecentralizedAverager([x.clone() for x in tensors1], dht=dht,
+                                                   compression_type=compression_type_pair, listen=False,
+                                                   target_group_size=2, prefix='mygroup', start=True)
+        averager2 = hivemind.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
+                                                   compression_type=compression_type_pair,
+                                                   target_group_size=2, prefix='mygroup', start=True)
+
+        for future in averager1.step(wait=False), averager2.step(wait=False):
+            future.result()
+
+        with averager1.get_tensors() as averaged_tensors:
+            results[compression_type_pair] = averaged_tensors
+
+    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
+    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
+    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
+    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
+
+    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
+    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
+    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
+    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
+
+    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
+    for i in range(2):
+        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
+        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
+
+
 def compute_mean_std(averagers, unbiased=True):
     results = []
     for averager in averagers:
@@ -201,68 +241,6 @@ def test_allgather():
     dht.shutdown()
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_allreduce_protocol():
-    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
-    peers = "alice", "bob", "carol", "colab"
-
-    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
-                       for i, peer in enumerate(peers)}
-
-    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
-    allreduce_protocols = [AllReduceProtocol(
-        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
-        ordered_group_endpoints=peers, part_sizes=(150, 200, 67, 0))
-        for peer in peers]
-
-    async def _accumulate(sender: Endpoint, recipient: Endpoint):
-        sender_allreduce = allreduce_protocols[peers.index(sender)]
-        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
-        averaged_part = await recipient_allreduce.accumulate_part(
-            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
-        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)
-
-    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
-                        if recipient != "colab"})
-
-    reference_tensors = [
-        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
-        for i in range(len(tensors_by_peer[peers[0]]))
-    ]
-
-    for peer, allreduce in zip(peers, allreduce_protocols):
-        assert allreduce.future.done()
-        averaged_tensors = await allreduce
-        assert len(averaged_tensors) == len(reference_tensors)
-        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
-                   for our, ref in zip(averaged_tensors, reference_tensors))
-
-
-@pytest.mark.forked
-def test_partitioning():
-    for _ in range(100):
-        tensors = []
-        for _ in range(random.randint(1, 5)):
-            ndim = random.randint(0, 4)
-            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
-            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
-            tensors.append(make_tensor(shape))
-
-        total_size = sum(map(torch.Tensor.numel, tensors))
-        if total_size == 0:
-            continue
-        num_chunks = random.randint(1, min(100, sum(x.numel() for x in tensors)))
-        part_sizes = load_balance_peers(total_size, [None] * num_chunks)
-        chunks = split_into_parts(tensors, part_sizes)
-        assert len(chunks) == num_chunks
-        shapes = [tensor.shape for tensor in tensors]
-        restored = restore_from_parts(chunks, shapes)
-        assert len(restored) == len(tensors)
-        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
-        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))
-
-
 def get_cost(vector_size, partitions, throughputs):
     return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
                for i in range(len(partitions)))

+ 37 - 1
tests/test_util_modules.py

@@ -11,6 +11,7 @@ from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 import hivemind
 from hivemind.utils import MSGPackSerializer
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
 from hivemind.utils.mpfuture import FutureStateError
 
 
@@ -142,6 +143,7 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     for compression_type in CompressionType.values():
         assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
 
+
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_channel_cache():
@@ -256,7 +258,7 @@ def test_split_parts():
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
         with pytest.raises(RuntimeError):
             deserialize_torch_tensor(combined)
-            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceProtocol
+            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceRunner
 
 
 def test_generic_data_classes():
@@ -271,3 +273,37 @@ def test_generic_data_classes():
     sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
     sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
     assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])
+
+
+@pytest.mark.asyncio
+async def test_asyncio_utils():
+    res = [i async for i, item in aenumerate(aiter('a', 'b', 'c'))]
+    assert res == list(range(len(res)))
+
+    num_steps = 0
+    async for elem in amap_in_executor(lambda x: x ** 2, aiter(*range(100)), max_prefetch=5):
+        assert elem == num_steps ** 2
+        num_steps += 1
+    assert num_steps == 100
+
+    ours = [elem async for elem in amap_in_executor(max, aiter(*range(7)), aiter(*range(-50, 50, 10)), max_prefetch=1)]
+    ref = list(map(max, range(7), range(-50, 50, 10)))
+    assert ours == ref
+
+    ours = [row async for row in azip(aiter('a', 'b', 'c'), aiter(1, 2, 3))]
+    ref = list(zip(['a', 'b', 'c'], [1, 2, 3]))
+    assert ours == ref
+
+    async def _aiterate():
+        yield 'foo'
+        yield 'bar'
+        yield 'baz'
+
+    iterator = _aiterate()
+    assert (await anext(iterator)) == 'foo'
+    tail = [item async for item in iterator]
+    assert tail == ['bar', 'baz']
+    with pytest.raises(StopAsyncIteration):
+        await anext(iterator)
+
+    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ['foo', 'bar', 'baz'] + list(range(5))