
Add Averager load balancing and public endpoints (#140)

* implement LP load balancing
* averager now relies on DHT to get public endpoint
* add scipy to requirements

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic, 4 years ago
commit 8466d722da

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.25'
+__version__ = '0.8.26'

+ 37 - 24
hivemind/client/averaging/__init__.py

@@ -12,18 +12,20 @@ from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
 import grpc
 import torch
+import numpy as np
 
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
 from hivemind.client.averaging.matchmaking import Matchmaking
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils import get_logger, Endpoint, Port, MPFuture, replace_port, GRPC_KEEPALIVE_OPTIONS, get_dht_time
+from hivemind.utils import get_logger, Endpoint, Port, MPFuture, GRPC_KEEPALIVE_OPTIONS, get_dht_time, MSGPackSerializer
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
 
 INITIAL_GROUP_NBITS = 3
+DataForGather = Any
 logger = get_logger(__name__)
 
 
@@ -52,6 +54,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
     :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
+    :param throughput: if specified, this value represents the network bandwidth available to the averager.
+          By default, the averager is assumed to have the average bandwidth of its group.
+          If throughput == 0, the averager will run in client-only mode (TODO: not implemented yet!)
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
            if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
@@ -65,15 +70,21 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     """
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    serializer = MSGPackSerializer
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
                  averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
                  allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-                 listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
+                 throughput: Optional[float] = None, min_vector_size: int = 0,
+                 listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
+        assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), "throughput must be a" \
+                                                                                                " nonnegative float32"
+        if not listen:
+            raise NotImplementedError("Client-only averaging is not implemented yet.")
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         if initial_group_bits is None:
@@ -96,8 +107,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
             min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout,
-            chunk_size_bytes=chunk_size_bytes, compression_type=compression_type)
-        self.averaging_alpha, self.allreduce_timeout = averaging_alpha, allreduce_timeout
+            chunk_size_bytes=chunk_size_bytes, compression_type=compression_type,
+            throughput=throughput, min_vector_size=min_vector_size)
+        self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
@@ -114,8 +126,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     @property
     def endpoint(self) -> Endpoint:
+        assert self.port is not None, "Averager is not running yet"
         if self._averager_endpoint is None:
-            self._averager_endpoint = replace_port(self.listen_on, self.port if self.port is not None else '*')
+            self._averager_endpoint = f"{self.dht.get_visible_address()}:{self.port}"
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
         return self._averager_endpoint
 
@@ -165,48 +178,51 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         else:
             logger.warning("DHT shutdown has no effect: the process is not alive")
 
-    def step(self, allow_retries: bool = True, timeout: Optional[float] = None, wait=True
-             ) -> Union[bool, MPFuture]:
+    def step(self, allow_retries: bool = True, gather: Optional[DataForGather] = None, timeout: Optional[float] = None,
+             wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
+
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
+        :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
+          (this operation is known as all-gather). The gathered data will be available as the output of this function.
         :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_step', [], dict(future=_future, allow_retries=allow_retries, timeout=timeout)))
+        self.pipe.send(('_step', [], dict(future=_future, gather=gather, allow_retries=allow_retries, timeout=timeout)))
         return future.result() if wait else future
 
-    async def _step(self, *, future: MPFuture, allow_retries: bool, timeout: Optional[float]):
+    async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
         loop = asyncio.get_event_loop()
         start_time = get_dht_time()
-
-        try_averaging = True
         group_id = None
 
-        while try_averaging:
+        while not future.done():
             try:
                 self._pending_group_assembled.clear()
-                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout)
+                gather_binary = self.serializer.dumps(gather)
+                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                 if allreduce_group is None:
                     raise AllreduceException("Averaging step failed: could not find a group.")
 
                 group_id = allreduce_group.group_id
                 self._running_groups[group_id] = allreduce_group
                 self._pending_group_assembled.set()
-                await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout)
-                update_ok = await loop.run_in_executor(None, self.update_tensors, allreduce_group)
+                await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
+                await loop.run_in_executor(None, self.update_tensors, allreduce_group)
 
                 # averaging is finished, exit the loop
-                future.set_result(update_ok)
-                try_averaging = False
+                gathered_items = map(self.serializer.loads, allreduce_group.gathered)
+                gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
+                future.set_result(gathered_data_by_peer)
 
             except AllreduceException:
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                    future.set_result(False)
-                    try_averaging = False
+                    future.set_result(None)
 
             except Exception as e:
                 future.set_exception(e)
@@ -215,11 +231,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 _ = self._running_groups.pop(group_id, None)
                 self._pending_group_assembled.set()
 
-    def update_tensors(self, allreduce_group: AllReduceRunner) -> bool:
+    def update_tensors(self, allreduce_group: AllReduceRunner):
         """
         a private (extendable) method that applies changes from a finished allreduce to local tensors
-
-        :return: True on success, False on failure
         """
         assert allreduce_group.return_deltas and allreduce_group.future.done()
         averaging_deltas = allreduce_group.future.result()
@@ -227,8 +241,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         with torch.no_grad(), self.get_tensors() as local_tensors:
             assert len(local_tensors) == len(self._averaged_tensors)
             for tensor, update in zip(local_tensors, averaging_deltas):
-                tensor.add_(update, alpha=self.averaging_alpha)
-            return True
+                tensor.add_(update, alpha=self._averaging_alpha)
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:

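For context, here is a minimal usage sketch of the new step(gather=...) interface (not part of this commit): each peer contributes a msgpack-serializable value, and on success step() returns a dict mapping every groupmate's endpoint to its contribution, or None if averaging failed. It assumes a local DHT; names and numbers below are illustrative.

import torch
import hivemind

dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
averagers = [hivemind.DecentralizedAverager([torch.randn(16)], dht=dht, start=True, prefix='demo',
                                            target_group_size=2, initial_group_bits='000')
             for _ in range(2)]

futures = [averager.step(wait=False, gather=dict(batch_size=32 * (i + 1)))
           for i, averager in enumerate(averagers)]
for future in futures:
    gathered = future.result()  # dict {endpoint: metadata} on success, None on failure
    if gathered is not None:
        print({endpoint: meta['batch_size'] for endpoint, meta in gathered.items()})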
+ 16 - 14
hivemind/client/averaging/allreduce.py

@@ -1,5 +1,5 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Iterator
+from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
 
 import grpc
 import torch
@@ -20,15 +20,17 @@ class AllReduceProtocol:
     :param tensors: local tensors that should be averaged with groupmates
     :param endpoint: your endpoint, must be included in ordered_group_endpoints
     :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    :param part_sizes: for each peer, the number of vector elements that this peer is responsible for averaging
     :param return_deltas: if True, returns the element-wise differences (averaged_tensors - original_tensors)
            default (False) - return averaged_tensors by themselves
     """
 
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], return_deltas: bool = False):
+                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
-        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
-        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, self.group_size)))
+        self.group_id, self.endpoint = group_id, endpoint
+        self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
+        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
         self.return_deltas = return_deltas
 
@@ -121,17 +123,18 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
 
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
                  ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
-                 chunk_size_bytes: int, return_deltas: bool = False):
-        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint,
+                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], gathered: Sequence[Any] = (),
+                 return_deltas: bool = False):
+        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
                          ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
-        self.compression_type, self.chunk_size_bytes = compression_type, chunk_size_bytes
+        self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
         self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future()
 
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
 
-    async def _average_one_part(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
-        """ Send one part of local tensors to one groupmate and collect the average for this part """
+    async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
+        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
         serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
         chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)
 
@@ -163,7 +166,7 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
         send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
         """
         try:
-            await asyncio.gather(self, *(self._average_one_part(peer, part)
+            await asyncio.gather(self, *(self._communicate_with_peer(peer, part)
                                          for peer, part in self.local_tensor_parts.items() if peer != self.endpoint))
             return await self
         except BaseException as e:
@@ -203,6 +206,7 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
                 yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=next(averaged_chunks))
                 for averaged_chunk in averaged_chunks:
                     yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
+
             except Exception as e:
                 self.set_exception(e)
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
@@ -213,12 +217,10 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
 
-def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor, ...]:
+def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int]) -> Tuple[torch.Tensor, ...]:
     """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
     flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
-    chunk_slices = torch.linspace(start=0, end=len(flat_tensor), steps=group_size + 1, dtype=torch.int64)
-    chunk_slices[-1] = len(flat_tensor)
-    return tuple(flat_tensor[chunk_slices[i]: chunk_slices[i + 1]] for i in range(group_size))
+    return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)
 
 
 def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:

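A quick illustration of the reworked partitioning helpers (not part of this commit): part sizes are now passed explicitly and must sum to the total number of elements across all tensors. The numbers below are made up.

import torch
from hivemind.client.averaging.allreduce import split_into_parts, restore_from_parts

tensors = [torch.randn(4, 3), torch.randn(5)]   # 17 elements in total
part_sizes = (10, 5, 2)                         # one entry per peer, must sum to 17
parts = split_into_parts(tensors, part_sizes)   # flat chunks of 10, 5 and 2 elements
restored = restore_from_parts(parts, [tensor.shape for tensor in tensors])
assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))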
+ 98 - 0
hivemind/client/averaging/load_balancing.py

@@ -0,0 +1,98 @@
+from typing import Sequence, Optional, Tuple
+import numpy as np
+import scipy.optimize
+
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
+    """
+    Find an optimal partitioning of weights for butterfly all-reduce given peer throughputs.
+    :param vector_size: total size of the averaged vector (in elements, not bytes)
+    :param throughputs: 1d array of non-negative throughputs for each peer, typically min(upload speed, download speed)
+    :param min_size: peers that can aggregate less than this many elements will be assigned nothing
+    :returns: an integer array where i-th element is the number of weights assigned to i-th peer
+    """
+    specified_throughputs = [throughput for throughput in throughputs if throughput is not None and throughput > 0]
+
+    if specified_throughputs:
+        default_throughput = np.mean(specified_throughputs)
+        throughputs = [throughput if throughput is not None else default_throughput for throughput in throughputs]
+        scores = optimize_parts_lp(vector_size, np.asarray(throughputs), min_size)
+    else:
+        assert not all(throughput == 0 for throughput in throughputs), "Must have at least one nonzero throughput"
+        scores = np.asarray([1.0 if throughput is None else 0.0 for throughput in throughputs])
+
+    return tuple(hagenbach_bishoff(vector_size, scores))
+
+
+def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int = 0, eps: float = 1e-15) -> np.ndarray:
+    """
+    This method solves an optimization problem to minimize the total allreduce time.
+    In butterfly all-reduce, each peer acts both as a "client" and as an "aggregator":
+    * a "client" splits its local vector into shards and sends each shard to one peer, then downloads the average
+    * an "aggregator" receives a certain part of vector components from all peers, aggregates and returns the average
+
+    Peer i network load as a "client" = vector_size * (1 - fraction_assigned_to_peer_i)
+    Peer i network load as an "aggregator" = vector_size * (group_size - 1) * fraction_assigned_to_peer_i
+    Peer i total communication = vector_size * [1 + (group_size - 2) * fraction_assigned_to_peer_i]
+    Total time = max_i (total_communication_for_peer_i / throughputs[i])
+
+    We solve this optimization problem by reducing it to linear programming with a minimax reduction
+    (see lecture notes: https://www.usna.edu/Users/math/dphillip/sa305.s15/phillips/lessons/32/32.pdf )
+
+    :returns: a vector of "scores", i-th score is proportional to the fraction of weights assigned to i-th peer
+    """
+    assert np.all(throughputs >= 0) and np.any(throughputs > 0)
+    permutation = np.argsort(-throughputs)
+    throughputs = throughputs[permutation]
+    is_nonzero = throughputs != 0
+
+    group_size = len(throughputs)
+    num_variables = group_size + 1  # [w_1, ..., w_N, xi]
+
+    c = np.zeros(num_variables)
+    c[-1] = 1.0  # optimize w.r.t. xi
+
+    # the constraints below are tuples (A, b) such that Ax <= b
+    nonnegative_weights = -np.eye(group_size, M=num_variables), np.zeros(group_size)
+    weights_sum_to_one = c[None, :] - 1.0, np.array([-1.0])
+    coeff_per_variable = (group_size - 2.0) / np.maximum(throughputs, eps)
+    coeff_matrix_minus_xi = np.hstack([np.diag(coeff_per_variable), -np.ones((group_size, 1))])
+    xi_is_maximum = coeff_matrix_minus_xi[is_nonzero], -1.0 / throughputs[is_nonzero]
+    force_max_weights = np.eye(group_size, M=num_variables), is_nonzero.astype(c.dtype)
+
+    A, b = list(map(np.concatenate, zip(nonnegative_weights, weights_sum_to_one, xi_is_maximum, force_max_weights)))
+
+    solution = scipy.optimize.linprog(c, A_ub=A, b_ub=b)
+    if solution.success:
+        peer_scores = solution.x[:group_size]
+        # if some peers have less than min_size elements, transfer their share to other peers (if any)
+        if np.max(peer_scores) >= min_size / float(vector_size):
+            peer_scores[peer_scores < min_size / float(vector_size)] = 0.0
+    else:
+        logger.error(f"Failed to solve load-balancing for bandwidths {throughputs}.")
+        peer_scores = np.ones(group_size)
+
+    return peer_scores[np.argsort(permutation)]
+
+
+def hagenbach_bishoff(vector_size: int, scores: Sequence[float]) -> Sequence[int]:
+    """
+    Split a vector between participants based on continuous fractions.
+    https://en.wikipedia.org/wiki/Hagenbach-Bischoff_system
+    The code is based on https://github.com/crflynn/voting
+
+    :param vector_size: the total number of elements to be split
+    :param scores: real-valued vector fractions for each peer
+    :returns: integer-valued partitions assigned to every peer
+    """
+    total_score = sum(scores)
+    allocated = [int(vector_size * score_i / total_score) for score_i in scores]
+    while sum(allocated) < vector_size:
+        quotients = [score / (allocated[idx] + 1) for idx, score in enumerate(scores)]
+        idx_max = quotients.index(max(quotients))
+        allocated[idx_max] += 1
+    return allocated

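A hedged usage sketch for the load balancer above (not part of this commit): peers with higher throughput receive proportionally larger partitions, None falls back to the mean of the known throughputs, and the partition sizes always sum to vector_size. Units and numbers are illustrative.

from hivemind.client.averaging.load_balancing import load_balance_peers

part_sizes = load_balance_peers(vector_size=1_000_000,
                                throughputs=[100e6, 50e6, None],  # any consistent bandwidth unit
                                min_size=1024)                    # shards smaller than this are reassigned
assert sum(part_sizes) == 1_000_000 and all(size >= 0 for size in part_sizes)
print(part_sizes)  # the fastest peer is assigned the largest shard in this example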
+ 48 - 20
hivemind/client/averaging/matchmaking.py

@@ -6,20 +6,20 @@ import contextlib
 import random
 from dataclasses import asdict
 from math import isfinite
-from typing import Sequence, Optional, AsyncIterator, Set, Tuple
+from typing import Sequence, Optional, AsyncIterator, Set, Tuple, Dict
 import asyncio
 
-import torch
 import grpc
+import torch
 
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceRunner, GroupID
+from hivemind.client.averaging.allreduce import AllReduceRunner
+from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.dht import DHTID, DHTExpiration, get_dht_time, GroupKey
 from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, TimedStorage
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.utils.grpc import ChannelCache
 
-
 logger = get_logger(__name__)
 
 
@@ -34,12 +34,12 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
       This deadlock only happens if averagers have outdated information on expirations (due to network delays).
       While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
       Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
-    
     """
 
     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
                  prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float, **allreduce_kwargs):
+                 averaging_expiration: float = 15, request_timeout: float, throughput: Optional[float] = None,
+                 min_vector_size: int, **allreduce_kwargs):
         assert '.' not in prefix, "group prefix must be a string without ."
         if request_timeout is None or request_timeout >= averaging_expiration:
             logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
@@ -50,8 +50,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.prefix, self.group_bits = prefix, initial_group_bits
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.throughput, self.min_vector_size = throughput, min_vector_size
         self.allreduce_kwargs = allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
+        self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)
 
         self.lock_looking_for_group = asyncio.Lock()
         self.lock_request_join_group = asyncio.Lock()
@@ -60,8 +62,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.assembled_group = asyncio.Future()
 
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
-        self.current_followers: Set[Endpoint] = set()  # iff i am a leader, this contains my followers excluding myself
+        self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
         self.potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration, target_group_size)
+        self.data_for_gather: bytes = None
 
     @property
     def is_looking_for_group(self):
@@ -82,8 +85,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
                f" current key = {self.current_group_key})"
 
-    async def look_for_group(self, *, timeout: Optional[float] = None) -> Optional[AllReduceRunner]:
+    async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
+                             ) -> Optional[AllReduceRunner]:
         """
+        :param data_for_gather: optionally send this data to all peers in the next group and gather it from every groupmate
+        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         """
@@ -91,6 +97,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             logger.info("Another look_for_group is already in progress. The current run will be scheduled after"
                         " the existing group is either assembled or disbanded.")
         async with self.lock_looking_for_group:
+            self.data_for_gather = data_for_gather
             request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
             try:
                 return await asyncio.wait_for(self.assembled_group, timeout=timeout)
@@ -116,6 +123,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 # note: the code above ensures that we send all followers away before creating new future
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()
+                self.data_for_gather = None
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
@@ -161,7 +169,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             async with self.lock_request_join_group:
                 leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                 call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
-                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time))
+                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
+                    throughput=self.throughput if self.throughput is not None else -1.0,
+                    gather=self.data_for_gather))
                 message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
@@ -182,8 +192,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                     async with self.lock_request_join_group:
-                        return await self.follower_assemble_group(
-                            leader, message.group_id, message.ordered_group_endpoints)
+                        return await self.follower_assemble_group(leader, message)
 
             if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
                 if message.suggested_leader and message.suggested_leader != self.endpoint:
@@ -218,7 +227,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                     yield reason_to_reject
                     return
 
-                self.current_followers.add(request.endpoint)
+                self.current_followers[request.endpoint] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -253,14 +262,15 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             allreduce_group = self.assembled_group.result()
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
-                ordered_group_endpoints=allreduce_group.ordered_group_endpoints)
+                ordered_group_endpoints=allreduce_group.ordered_group_endpoints,
+                part_sizes=allreduce_group.part_sizes, gathered=allreduce_group.gathered)
 
         except Exception as e:
             logger.exception(e)
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
 
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.discard(request.endpoint)
+            self.current_followers.pop(request.endpoint, None)
             self.follower_was_discarded.set()
 
     def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
@@ -297,22 +307,40 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         ordered_group_endpoints = list(self.current_followers)
         ordered_group_endpoints.append(self.endpoint)
         random.shuffle(ordered_group_endpoints)
+
+        throughputs, gathered = [], []
+        for endpoint in ordered_group_endpoints:
+            if endpoint == self.endpoint:
+                throughputs.append(self.throughput)
+                gathered.append(self.data_for_gather)
+            else:
+                follower_info = self.current_followers[endpoint]
+                throughputs.append(follower_info.throughput if follower_info.throughput >= 0 else None)
+                gathered.append(follower_info.gather if follower_info.gather else None)
+
+        part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
+
         logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
+                                          ordered_group_endpoints=ordered_group_endpoints, part_sizes=part_sizes,
+                                          gathered=gathered, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
         return allreduce_group
 
-    async def follower_assemble_group(self, leader: Endpoint, group_id: GroupID,
-                                      ordered_group_endpoints: Sequence[Endpoint]) -> AllReduceRunner:
+    async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> AllReduceRunner:
         """ Prepare to run allreduce using a list of peers provided by our leader """
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
         assert not self.assembled_group.done()
-        logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
+
+        group_id, ordered_group_endpoints, part_sizes = msg.group_id, msg.ordered_group_endpoints, msg.part_sizes
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
+        assert len(ordered_group_endpoints) == len(part_sizes) == len(msg.gathered)
+
+        logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
+                                          ordered_group_endpoints=tuple(ordered_group_endpoints),
+                                          part_sizes=tuple(part_sizes), gathered=msg.gathered, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
         return allreduce_group
 

+ 6 - 2
hivemind/proto/averaging.proto

@@ -33,6 +33,8 @@ message JoinRequest {
   string endpoint = 1;          // A follower accepts incoming allreduce requests at this address
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
+  bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
+  float throughput = 5;         // Follower has this bandwidth for averaging (0 = default, negative = client only)
 }
 
 message MessageFromLeader {
@@ -40,11 +42,13 @@ message MessageFromLeader {
   bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
   string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
   repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
+  repeated int32 part_sizes = 5;  // sizes of the tensor parts assigned to each peer, same order as endpoints
+  repeated bytes gathered = 6;  // metadata (gather) from all groupmates in the same order as their endpoints
 }
 
 message AveragingData {
   MessageCode code = 1;     // in case of a protocol violation, this will be the error message
-  bytes group_id = 2;        // a unique group identifier, same as in MessageFromLeader
+  bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   string endpoint = 3;      // sender's rpc endpoint, used for coordination
-  Tensor tensor_part = 4;    // either peer's local tensor part (rpc input) or group average of this part (rpc output)
+  Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
 }

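For reference, a hedged sketch of how the extended JoinRequest is filled in from Python (not part of this commit), mirroring Matchmaking.request_join_group above; the endpoint and metadata values are made up.

from hivemind.proto import averaging_pb2
from hivemind.utils import MSGPackSerializer

request = averaging_pb2.JoinRequest(
    endpoint='192.0.2.1:31337',
    schema_hash=b'\x00' * 32,     # placeholder for the real schema hash
    expiration=1234567890.0,
    throughput=100e6,             # advertised bandwidth; the matchmaking code sends -1.0 when it is unspecified
    gather=MSGPackSerializer.dumps(dict(batch_size=256)))  # opaque bytes returned to every groupmate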
+ 1 - 0
requirements.txt

@@ -1,6 +1,7 @@
 PyYAML
 torch>=1.6.0
 numpy>=1.17
+scipy>=1.2.1
 prefetch_generator>=1.0.1
 msgpack>=0.5.6
 sortedcontainers

+ 76 - 5
tests/test_averaging.py

@@ -1,11 +1,12 @@
 import asyncio
 import random
-import time
 
+import numpy as np
 import torch
 import pytest
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
+from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.utils import Endpoint
 
 
@@ -35,7 +36,7 @@ def test_getset_averagers():
 
 @pytest.mark.forked
 def test_allreduce_once():
-    dht = hivemind.DHT(start=True)
+    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
@@ -53,7 +54,9 @@ def test_allreduce_once():
     for averager in averagers:
         futures.append(averager.step(wait=False))
     for future in futures:
-        assert future.result() is True
+        result = future.result()
+        for averager in averagers:
+            assert averager.endpoint in result
 
     for averager in averagers:
         with averager.get_tensors() as averaged_tensors:
@@ -61,6 +64,31 @@ def test_allreduce_once():
                 assert torch.allclose(ref, our, atol=1e-6)
 
 
+@pytest.mark.forked
+def test_allgather():
+    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+    averagers = [hivemind.DecentralizedAverager(torch.ones(1), dht=dht, target_group_size=4, averaging_expiration=15,
+                                                prefix='mygroup', initial_group_bits='000', listen_on='127.0.0.1:*',
+                                                start=True)
+                 for _ in range(8)]
+
+    futures = []
+    for i, averager in enumerate(averagers):
+        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo='bar')))
+
+    assert len(set(repr(sorted(future.result())) for future in futures)) == 2
+
+    reference_metadata = {averager.endpoint: dict(batch_size=123 + i, foo='bar')
+                          for i, averager in enumerate(averagers)}
+    for future in futures:
+        gathered = future.result()
+
+        assert len(gathered) == 4
+
+        for endpoint in gathered:
+            assert gathered[endpoint] == reference_metadata[endpoint]
+
+
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_allreduce_protocol():
@@ -72,7 +100,8 @@ async def test_allreduce_protocol():
 
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
     allreduce_protocols = [AllReduceProtocol(
-        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer], ordered_group_endpoints=peers)
+        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
+        ordered_group_endpoints=peers, part_sizes=(150, 200, 67))
         for peer in peers]
 
     async def _accumulate(sender: Endpoint, recipient: Endpoint):
@@ -112,10 +141,52 @@ def test_partitioning():
         if total_size == 0:
             continue
         num_chunks = random.randint(1, min(1000, sum(x.numel() for x in tensors)))
-        chunks = split_into_parts(tensors, group_size=num_chunks)
+        part_sizes = load_balance_peers(total_size, [None] * num_chunks)
+        chunks = split_into_parts(tensors, part_sizes)
         assert len(chunks) == num_chunks
         shapes = [tensor.shape for tensor in tensors]
         restored = restore_from_parts(chunks, shapes)
         assert len(restored) == len(tensors)
         assert all(new.shape == old.shape for new, old in zip(restored, tensors))
         assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))
+
+
+def get_cost(vector_size, partitions, throughputs):
+    return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
+               for i in range(len(partitions)))
+
+
+def check_optimality(vector_size, throughputs, ref_partitions):
+    partitions = list(load_balance_peers(vector_size, throughputs))
+    assert get_cost(vector_size, partitions, throughputs) <= get_cost(vector_size, ref_partitions, throughputs)
+
+
+@pytest.mark.forked
+def test_load_balancing():
+    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
+    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
+    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
+    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
+    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
+    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])
+    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
+    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
+    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
+    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)
+
+    assert load_balance_peers(100, (None, None)) == (50, 50)
+    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
+    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)
+
+    with pytest.raises(AssertionError):
+        load_balance_peers(100, (0, 0, 0))
+
+    for i in range(10):
+        vector_size = np.random.randint(1, 1024 ** 3)
+        num_peers = np.random.randint(1, 256)
+        scale = 1e-9 + np.random.rand() * 1e5
+        throughputs = np.random.rand(num_peers) * scale + 1e-6
+        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])
+        assignment = load_balance_peers(vector_size, throughputs, min_size)
+        assert np.sum(assignment) == vector_size
+        assert np.min(assignment) >= 0