
Add Averager load balancing and public endpoints (#140)

* implement LP load balancing
* averager now relies on DHT to get public endpoint
* scipy to requirements

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic committed 4 years ago
commit 8466d722da

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.25'
+__version__ = '0.8.26'

+ 37 - 24
hivemind/client/averaging/__init__.py

@@ -12,18 +12,20 @@ from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
 import grpc
 import torch
+import numpy as np
 
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
 from hivemind.client.averaging.matchmaking import Matchmaking
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils import get_logger, Endpoint, Port, MPFuture, replace_port, GRPC_KEEPALIVE_OPTIONS, get_dht_time
+from hivemind.utils import get_logger, Endpoint, Port, MPFuture, GRPC_KEEPALIVE_OPTIONS, get_dht_time, MSGPackSerializer
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
 
 INITIAL_GROUP_NBITS = 3
+DataForGather = Any
 logger = get_logger(__name__)
 
 
@@ -52,6 +54,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
     :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
+    :param throughput: if specified, this value represents the network bandwidth available to this averager.
+          By default, the averager is assumed to have the average bandwidth of its group.
+          If throughput == 0, the averager will run in client-only mode (TODO: not implemented yet!)
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
             if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
@@ -65,15 +70,21 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     """
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    serializer = MSGPackSerializer
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
                  averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
                  allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-                 listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
+                 throughput: Optional[float] = None, min_vector_size: int = 0,
+                 listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
+        assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), "throughput must be a" \
+                                                                                                " nonnegative float32"
+        if not listen:
+            raise NotImplementedError("Client-only averaging is not implemented yet.")
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         if initial_group_bits is None:
@@ -96,8 +107,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
             min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout,
-            chunk_size_bytes=chunk_size_bytes, compression_type=compression_type)
-        self.averaging_alpha, self.allreduce_timeout = averaging_alpha, allreduce_timeout
+            chunk_size_bytes=chunk_size_bytes, compression_type=compression_type,
+            throughput=throughput, min_vector_size=min_vector_size)
+        self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
@@ -114,8 +126,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     @property
     def endpoint(self) -> Endpoint:
+        assert self.port is not None, "Averager is not running yet"
         if self._averager_endpoint is None:
-            self._averager_endpoint = replace_port(self.listen_on, self.port if self.port is not None else '*')
+            self._averager_endpoint = f"{self.dht.get_visible_address()}:{self.port}"
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
         return self._averager_endpoint
 
@@ -165,48 +178,51 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         else:
             logger.warning("DHT shutdown has no effect: the process is not alive")
 
-    def step(self, allow_retries: bool = True, timeout: Optional[float] = None, wait=True
-             ) -> Union[bool, MPFuture]:
+    def step(self, allow_retries: bool = True, gather: Optional[DataForGather] = None, timeout: Optional[float] = None,
+             wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging with that group
+
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
+        :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
+          (this operation is known as all-gather). The gathered data will be available as the output of this function.
         :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: on success, update averaged_tensors and return data gathered from all groupmates; on failure, return None
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_step', [], dict(future=_future, allow_retries=allow_retries, timeout=timeout)))
+        self.pipe.send(('_step', [], dict(future=_future, gather=gather, allow_retries=allow_retries, timeout=timeout)))
         return future.result() if wait else future
 
-    async def _step(self, *, future: MPFuture, allow_retries: bool, timeout: Optional[float]):
+    async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
         loop = asyncio.get_event_loop()
         start_time = get_dht_time()
-
-        try_averaging = True
         group_id = None
 
-        while try_averaging:
+        while not future.done():
             try:
                 self._pending_group_assembled.clear()
-                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout)
+                gather_binary = self.serializer.dumps(gather)
+                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                 if allreduce_group is None:
                     raise AllreduceException("Averaging step failed: could not find a group.")
 
                 group_id = allreduce_group.group_id
                 self._running_groups[group_id] = allreduce_group
                 self._pending_group_assembled.set()
-                await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout)
-                update_ok = await loop.run_in_executor(None, self.update_tensors, allreduce_group)
+                await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
+                await loop.run_in_executor(None, self.update_tensors, allreduce_group)
 
                 # averaging is finished, exit the loop
-                future.set_result(update_ok)
-                try_averaging = False
+                gathered_items = map(self.serializer.loads, allreduce_group.gathered)
+                gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
+                future.set_result(gathered_data_by_peer)
 
             except AllreduceException:
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                    future.set_result(False)
-                    try_averaging = False
+                    future.set_result(None)
 
             except Exception as e:
                 future.set_exception(e)
@@ -215,11 +231,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 _ = self._running_groups.pop(group_id, None)
                 self._pending_group_assembled.set()
 
-    def update_tensors(self, allreduce_group: AllReduceRunner) -> bool:
+    def update_tensors(self, allreduce_group: AllReduceRunner):
         """
         a private (extendable) method that applies changes from a finished allreduce to local tensors
-
-        :return: True on success, False on failure
         """
         assert allreduce_group.return_deltas and allreduce_group.future.done()
         averaging_deltas = allreduce_group.future.result()
@@ -227,8 +241,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         with torch.no_grad(), self.get_tensors() as local_tensors:
             assert len(local_tensors) == len(self._averaged_tensors)
             for tensor, update in zip(local_tensors, averaging_deltas):
-                tensor.add_(update, alpha=self.averaging_alpha)
-            return True
+                tensor.add_(update, alpha=self._averaging_alpha)
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
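
A minimal usage sketch of the new arguments, following the constructor signature and the step() docstring above (the DHT endpoint, tensor sizes and group parameters are illustrative, mirroring tests/test_averaging.py below):

import torch
import hivemind

dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
averagers = [
    hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(16)], dht=dht, start=True,
        prefix='mygroup', initial_group_bits='000', target_group_size=2,
        throughput=100e6,  # this peer's bandwidth; None (default) = assume the average bandwidth of the group
    )
    for _ in range(2)
]

# run one round of averaging and all-gather a small metadata dict from every groupmate
futures = [averager.step(wait=False, gather=dict(batch_size=32 + i)) for i, averager in enumerate(averagers)]
for future in futures:
    gathered = future.result()  # {endpoint: metadata} on success, None on failure
    print(gathered)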

+ 16 - 14
hivemind/client/averaging/allreduce.py

@@ -1,5 +1,5 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Iterator
+from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
 
 import grpc
 import torch
@@ -20,15 +20,17 @@ class AllReduceProtocol:
     :param tensors: local tensors that should be averaged with groupmates
     :param endpoint: your endpoint, must be included in ordered_group_endpoints
     :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    :param part_sizes: for each peer, the number of vector elements that this peer is responsible for averaging
     :param return_deltas: if True, returns the element-wise differences (averaged_tensors - original_tensors)
            default (False) - return averaged_tensors by themselves
     """
 
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], return_deltas: bool = False):
+                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
-        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
-        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, self.group_size)))
+        self.group_id, self.endpoint = group_id, endpoint
+        self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
+        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
         self.return_deltas = return_deltas
 
@@ -121,17 +123,18 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
 
     def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
                  ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
-                 chunk_size_bytes: int, return_deltas: bool = False):
-        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint,
+                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], gathered: Sequence[Any] = (),
+                 return_deltas: bool = False):
+        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
                          ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
-        self.compression_type, self.chunk_size_bytes = compression_type, chunk_size_bytes
+        self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
         self.averaged_part_stream: asyncio.Future[Tuple[runtime_pb2.Tensor, ...]] = asyncio.Future()
 
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
 
-    async def _average_one_part(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
-        """ Send one part of local tensors to one groupmate and collect the average for this part """
+    async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
+        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
         serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
         chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)
 
@@ -163,7 +166,7 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
         send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
         """
         try:
-            await asyncio.gather(self, *(self._average_one_part(peer, part)
+            await asyncio.gather(self, *(self._communicate_with_peer(peer, part)
                                          for peer, part in self.local_tensor_parts.items() if peer != self.endpoint))
             return await self
         except BaseException as e:
@@ -203,6 +206,7 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
                 yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=next(averaged_chunks))
                 for averaged_chunk in averaged_chunks:
                     yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
+
             except Exception as e:
                 self.set_exception(e)
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
@@ -213,12 +217,10 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
 
-def split_into_parts(tensors: Sequence[torch.Tensor], group_size: int) -> Tuple[torch.Tensor, ...]:
+def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int]) -> Tuple[torch.Tensor, ...]:
     """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
     flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
-    chunk_slices = torch.linspace(start=0, end=len(flat_tensor), steps=group_size + 1, dtype=torch.int64)
-    chunk_slices[-1] = len(flat_tensor)
-    return tuple(flat_tensor[chunk_slices[i]: chunk_slices[i + 1]] for i in range(group_size))
+    return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)
 
 
 def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
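
A short sketch of the updated partitioning helpers (signatures as in the hunk above; the tensors and part sizes are illustrative and must sum to the total number of elements):

import torch
from hivemind.client.averaging.allreduce import split_into_parts, restore_from_parts

tensors = [torch.randn(4, 5), torch.zeros(3)]      # 23 elements in total
part_sizes = (10, 8, 5)                            # per-peer shares, e.g. computed by load_balance_peers
parts = split_into_parts(tensors, part_sizes)      # tuple of flat 1d chunks, one per peer
assert tuple(part.numel() for part in parts) == part_sizes

restored = restore_from_parts(parts, [tensor.shape for tensor in tensors])
assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))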

+ 98 - 0
hivemind/client/averaging/load_balancing.py

@@ -0,0 +1,98 @@
+from typing import Sequence, Optional, Tuple
+import numpy as np
+import scipy.optimize
+
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
+    """
+    Find an optimal partitioning of weights for butterfly all-reduce given peer throughputs.
+    :param vector_size: total size of the averaged vector (in elements, not bytes)
+    :param throughputs: 1d array of non-negative throughputs for each peer, typically min(upload speed, download speed)
+    :param min_size: peers that can aggregate less than this many elements will be assigned nothing
+    :returns: an integer array where i-th element is the number of weights assigned to i-th peer
+    """
+    specified_throughputs = [throughput for throughput in throughputs if throughput is not None and throughput > 0]
+
+    if specified_throughputs:
+        default_throughput = np.mean(specified_throughputs)
+        throughputs = [throughput if throughput is not None else default_throughput for throughput in throughputs]
+        scores = optimize_parts_lp(vector_size, np.asarray(throughputs), min_size)
+    else:
+        assert not all(throughput == 0 for throughput in throughputs), "Must have at least one nonzero throughput"
+        scores = np.asarray([1.0 if throughput is None else 0.0 for throughput in throughputs])
+
+    return tuple(hagenbach_bishoff(vector_size, scores))
+
+
+def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int = 0, eps: float = 1e-15) -> np.ndarray:
+    """
+    This method solves an optimization problem to minimize the total allreduce time.
+    In butterfly all-reduce, each peer acts both as a "client" and as an "aggregator":
+    * a "client" splits his local vector into shards and sends each shard to one peer, then downloads the average
+    * an "aggregator" receives a certain part of vector components from all peers, aggregates and returns the average
+
+    Peer i network load as a "client" = vector_size * (1 - fraction_assigned_to_peer_i)
+    Peer i network load as an "aggregator" = vector_size * (group_size - 1) * fraction_assigned_to_peer_i
+    Peer i total communication = vector_size * [1 + (group_size - 2) * fraction_assigned_to_peer_i]
+    Total time = max_i (total_communication_for_peer_i / throughputs[i])
+
+    We solve this optimization problem by reducing it to linear programming with a minimax reduction
+    (see lecture notes: https://www.usna.edu/Users/math/dphillip/sa305.s15/phillips/lessons/32/32.pdf )
+
+    :returns: a vector of "scores", i-th score is proportional to the fraction of weights assigned to i-th peer
+    """
+    assert np.all(throughputs >= 0) and np.any(throughputs > 0)
+    permutation = np.argsort(-throughputs)
+    throughputs = throughputs[permutation]
+    is_nonzero = throughputs != 0
+
+    group_size = len(throughputs)
+    num_variables = group_size + 1  # [w_1, ..., w_N, xi]
+
+    c = np.zeros(num_variables)
+    c[-1] = 1.0  # optimize w.r.t. xi
+
+    # the constraints below are tuples (A, b) such that Ax <= b
+    nonnegative_weights = -np.eye(group_size, M=num_variables), np.zeros(group_size)
+    weights_sum_to_one = c[None, :] - 1.0, np.array([-1.0])
+    coeff_per_variable = (group_size - 2.0) / np.maximum(throughputs, eps)
+    coeff_matrix_minus_xi = np.hstack([np.diag(coeff_per_variable), -np.ones((group_size, 1))])
+    xi_is_maximum = coeff_matrix_minus_xi[is_nonzero], -1.0 / throughputs[is_nonzero]
+    force_max_weights = np.eye(group_size, M=num_variables), is_nonzero.astype(c.dtype)
+
+    A, b = list(map(np.concatenate, zip(nonnegative_weights, weights_sum_to_one, xi_is_maximum, force_max_weights)))
+
+    solution = scipy.optimize.linprog(c, A_ub=A, b_ub=b)
+    if solution.success:
+        peer_scores = solution.x[:group_size]
+        # if some peers have less than min_size elements, transfer their share to other peers (if any)
+        if np.max(peer_scores) >= min_size / float(vector_size):
+            peer_scores[peer_scores < min_size / float(vector_size)] = 0.0
+    else:
+        logger.error(f"Failed to solve load-balancing for bandwidths {throughputs}.")
+        peer_scores = np.ones(group_size)
+
+    return peer_scores[np.argsort(permutation)]
+
+
+def hagenbach_bishoff(vector_size: int, scores: Sequence[float]) -> Sequence[int]:
+    """
+    Split a vector between participants in proportion to real-valued scores.
+    https://en.wikipedia.org/wiki/Hagenbach-Bischoff_system
+    The code is based on https://github.com/crflynn/voting
+
+    :param vector_size: the total number of elements to be split
+    :param scores: real-valued vector fractions for each peer
+    :returns: integer-valued partitions assigned to every peer
+    """
+    total_score = sum(scores)
+    allocated = [int(vector_size * score_i / total_score) for score_i in scores]
+    while sum(allocated) < vector_size:
+        quotients = [score / (allocated[idx] + 1) for idx, score in enumerate(scores)]
+        idx_max = quotients.index(max(quotients))
+        allocated[idx_max] += 1
+    return allocated
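
The intended behaviour of load_balance_peers, taken from the assertions in tests/test_averaging.py below (None means the peer's bandwidth is unknown and defaults to the group average):

from hivemind.client.averaging.load_balancing import load_balance_peers

print(load_balance_peers(100, (None, None)))                    # (50, 50): unknown bandwidths are split evenly
print(load_balance_peers(100, (0, 0, 0, None, None)))           # (0, 0, 0, 50, 50): zero-throughput peers aggregate nothing
print(load_balance_peers(60, (0.55, 0.44, 0.40), min_size=10))  # (41, 19, 0): the slowest peer falls below min_size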

+ 48 - 20
hivemind/client/averaging/matchmaking.py

@@ -6,20 +6,20 @@ import contextlib
 import random
 from dataclasses import asdict
 from math import isfinite
-from typing import Sequence, Optional, AsyncIterator, Set, Tuple
+from typing import Sequence, Optional, AsyncIterator, Set, Tuple, Dict
 import asyncio
 
-import torch
 import grpc
+import torch
 
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceRunner, GroupID
+from hivemind.client.averaging.allreduce import AllReduceRunner
+from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.dht import DHTID, DHTExpiration, get_dht_time, GroupKey
 from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, TimedStorage
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
+from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.utils.grpc import ChannelCache
 
-
 logger = get_logger(__name__)
 
 
@@ -34,12 +34,12 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
       This deadlock only happens if averagers have outdated information on expirations (due to network delays). 
       While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
       Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
-    
     """
 
     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *,
                  prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float, **allreduce_kwargs):
+                 averaging_expiration: float = 15, request_timeout: float, throughput: Optional[float] = None,
+                 min_vector_size: int, **allreduce_kwargs):
         assert '.' not in prefix, "group prefix must be a string without ."
         if request_timeout is None or request_timeout >= averaging_expiration:
             logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
@@ -50,8 +50,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.prefix, self.group_bits = prefix, initial_group_bits
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.throughput, self.min_vector_size = throughput, min_vector_size
         self.allreduce_kwargs = allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
+        self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)
 
         self.lock_looking_for_group = asyncio.Lock()
         self.lock_request_join_group = asyncio.Lock()
@@ -60,8 +62,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.assembled_group = asyncio.Future()
 
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
-        self.current_followers: Set[Endpoint] = set()  # iff i am a leader, this contains my followers excluding myself
+        self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
         self.potential_leaders = PotentialLeaders(endpoint, dht, averaging_expiration, target_group_size)
+        self.data_for_gather: Optional[bytes] = None
 
     @property
     def is_looking_for_group(self):
@@ -82,8 +85,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
                f" current key = {self.current_group_key})"
 
-    async def look_for_group(self, *, timeout: Optional[float] = None) -> Optional[AllReduceRunner]:
+    async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
+                             ) -> Optional[AllReduceRunner]:
         """
+        :param data_for_gather: optionally send this data to all peers in the next group and gather it from every groupmate
+        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         """
@@ -91,6 +97,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             logger.info("Another look_for_group is already in progress. The current run will be scheduled after"
                         " the existing group is either assembled or disbanded.")
         async with self.lock_looking_for_group:
+            self.data_for_gather = data_for_gather
             request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
             try:
                 return await asyncio.wait_for(self.assembled_group, timeout=timeout)
@@ -116,6 +123,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 # note: the code above ensures that we send all followers away before creating new future
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()
+                self.data_for_gather = None
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
@@ -161,7 +169,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             async with self.lock_request_join_group:
                 leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                 call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
-                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time))
+                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
+                    throughput=self.throughput if self.throughput is not None else -1.0,
+                    gather=self.data_for_gather))
                 message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
@@ -182,8 +192,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                     async with self.lock_request_join_group:
-                        return await self.follower_assemble_group(
-                            leader, message.group_id, message.ordered_group_endpoints)
+                        return await self.follower_assemble_group(leader, message)
 
             if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
                 if message.suggested_leader and message.suggested_leader != self.endpoint:
@@ -218,7 +227,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                     yield reason_to_reject
                     return
 
-                self.current_followers.add(request.endpoint)
+                self.current_followers[request.endpoint] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -253,14 +262,15 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             allreduce_group = self.assembled_group.result()
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
-                ordered_group_endpoints=allreduce_group.ordered_group_endpoints)
+                ordered_group_endpoints=allreduce_group.ordered_group_endpoints,
+                part_sizes=allreduce_group.part_sizes, gathered=allreduce_group.gathered)
 
         except Exception as e:
             logger.exception(e)
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
 
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.discard(request.endpoint)
+            self.current_followers.pop(request.endpoint, None)
             self.follower_was_discarded.set()
 
     def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
@@ -297,22 +307,40 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         ordered_group_endpoints = list(self.current_followers)
         ordered_group_endpoints.append(self.endpoint)
         random.shuffle(ordered_group_endpoints)
+
+        throughputs, gathered = [], []
+        for endpoint in ordered_group_endpoints:
+            if endpoint == self.endpoint:
+                throughputs.append(self.throughput)
+                gathered.append(self.data_for_gather)
+            else:
+                follower_info = self.current_followers[endpoint]
+                throughputs.append(follower_info.throughput if follower_info.throughput >= 0 else None)
+                gathered.append(follower_info.gather if follower_info.gather else None)
+
+        part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
+
         logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
+                                          ordered_group_endpoints=ordered_group_endpoints, part_sizes=part_sizes,
+                                          gathered=gathered, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
         return allreduce_group
 
-    async def follower_assemble_group(self, leader: Endpoint, group_id: GroupID,
-                                      ordered_group_endpoints: Sequence[Endpoint]) -> AllReduceRunner:
+    async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> AllReduceRunner:
         """ Prepare to run allreduce using a list of peers provided by our leader """
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
         assert not self.assembled_group.done()
-        logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
+
+        group_id, ordered_group_endpoints, part_sizes = msg.group_id, msg.ordered_group_endpoints, msg.part_sizes
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
+        assert len(ordered_group_endpoints) == len(part_sizes) == len(msg.gathered)
+
+        logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-                                          ordered_group_endpoints=ordered_group_endpoints, **self.allreduce_kwargs)
+                                          ordered_group_endpoints=tuple(ordered_group_endpoints),
+                                          part_sizes=tuple(part_sizes), gathered=msg.gathered, **self.allreduce_kwargs)
         self.assembled_group.set_result(allreduce_group)
         return allreduce_group
 

+ 6 - 2
hivemind/proto/averaging.proto

@@ -33,6 +33,8 @@ message JoinRequest {
   string endpoint = 1;          // A follower accepts incoming allreduce requests at this address
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
+  bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
+  float throughput = 5;         // Follower's bandwidth for averaging (negative = not specified, 0 = client-only mode)
 }
 
 message MessageFromLeader {
@@ -40,11 +42,13 @@ message MessageFromLeader {
   bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
   string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
   repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
+  repeated int32 part_sizes = 5;  // the number of tensor elements assigned to each peer, same order as endpoints
+  repeated bytes gathered = 6;  // metadata (gather) from all groupmates in the same order as their endpoints
 }
 
 message AveragingData {
   MessageCode code = 1;     // in case of a protocol violation, this will be the error message
-  bytes group_id = 2;        // a unique group identifier, same as in MessageFromLeader
+  bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   string endpoint = 3;      // sender's rpc endpoint, used for coordination
-  Tensor tensor_part = 4;    // either peer's local tensor part (rpc input) or group average of this part (rpc output)
+  Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
 }
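
For illustration only, a hedged sketch of how a follower fills these fields (mirroring rpc_join_group in matchmaking.py above; the endpoint, hash and expiration values are placeholders):

from hivemind.proto import averaging_pb2

request = averaging_pb2.JoinRequest(
    endpoint='203.0.113.1:1337',   # placeholder public address
    schema_hash=b'\x00' * 32,      # placeholder; must match the leader's tensor schema hash
    expiration=1234567890.0,       # DHT time by which the follower wants allreduce to begin
    throughput=-1.0,               # negative = not specified, the leader assumes the group-average bandwidth
    gather=b'',                    # optional msgpack-serialized metadata for all-gather
)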

+ 1 - 0
requirements.txt

@@ -1,6 +1,7 @@
 PyYAML
 torch>=1.6.0
 numpy>=1.17
+scipy>=1.2.1
 prefetch_generator>=1.0.1
 msgpack>=0.5.6
 sortedcontainers

+ 76 - 5
tests/test_averaging.py

@@ -1,11 +1,12 @@
 import asyncio
 import random
-import time
 
+import numpy as np
 import torch
 import pytest
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
+from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.utils import Endpoint
 
 
@@ -35,7 +36,7 @@ def test_getset_averagers():
 
 @pytest.mark.forked
 def test_allreduce_once():
-    dht = hivemind.DHT(start=True)
+    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
@@ -53,7 +54,9 @@ def test_allreduce_once():
     for averager in averagers:
         futures.append(averager.step(wait=False))
     for future in futures:
-        assert future.result() is True
+        result = future.result()
+        for averager in averagers:
+            assert averager.endpoint in result
 
     for averager in averagers:
         with averager.get_tensors() as averaged_tensors:
@@ -61,6 +64,31 @@ def test_allreduce_once():
                 assert torch.allclose(ref, our, atol=1e-6)
 
 
+@pytest.mark.forked
+def test_allgather():
+    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+    averagers = [hivemind.DecentralizedAverager(torch.ones(1), dht=dht, target_group_size=4, averaging_expiration=15,
+                                                prefix='mygroup', initial_group_bits='000', listen_on='127.0.0.1:*',
+                                                start=True)
+                 for _ in range(8)]
+
+    futures = []
+    for i, averager in enumerate(averagers):
+        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo='bar')))
+
+    assert len(set(repr(sorted(future.result())) for future in futures)) == 2
+
+    reference_metadata = {averager.endpoint: dict(batch_size=123 + i, foo='bar')
+                          for i, averager in enumerate(averagers)}
+    for future in futures:
+        gathered = future.result()
+
+        assert len(gathered) == 4
+
+        for endpoint in gathered:
+            assert gathered[endpoint] == reference_metadata[endpoint]
+
+
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_allreduce_protocol():
@@ -72,7 +100,8 @@ async def test_allreduce_protocol():
 
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
     allreduce_protocols = [AllReduceProtocol(
-        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer], ordered_group_endpoints=peers)
+        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
+        ordered_group_endpoints=peers, part_sizes=(150, 200, 67))
         for peer in peers]
 
     async def _accumulate(sender: Endpoint, recipient: Endpoint):
@@ -112,10 +141,52 @@ def test_partitioning():
         if total_size == 0:
             continue
         num_chunks = random.randint(1, min(1000, sum(x.numel() for x in tensors)))
-        chunks = split_into_parts(tensors, group_size=num_chunks)
+        part_sizes = load_balance_peers(total_size, [None] * num_chunks)
+        chunks = split_into_parts(tensors, part_sizes)
         assert len(chunks) == num_chunks
         shapes = [tensor.shape for tensor in tensors]
         restored = restore_from_parts(chunks, shapes)
         assert len(restored) == len(tensors)
         assert all(new.shape == old.shape for new, old in zip(restored, tensors))
         assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))
+
+
+def get_cost(vector_size, partitions, throughputs):
+    return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
+               for i in range(len(partitions)))
+
+
+def check_optimality(vector_size, throughputs, ref_partitions):
+    partitions = list(load_balance_peers(vector_size, throughputs))
+    assert get_cost(vector_size, partitions, throughputs) <= get_cost(vector_size, ref_partitions, throughputs)
+
+
+@pytest.mark.forked
+def test_load_balancing():
+    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
+    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
+    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
+    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
+    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
+    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])
+    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
+    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
+    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
+    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)
+
+    assert load_balance_peers(100, (None, None)) == (50, 50)
+    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
+    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)
+
+    with pytest.raises(AssertionError):
+        load_balance_peers(100, (0, 0, 0))
+
+    for i in range(10):
+        vector_size = np.random.randint(1, 1024 ** 3)
+        num_peers = np.random.randint(1, 256)
+        scale = 1e-9 + np.random.rand() * 1e5
+        throughputs = np.random.rand(num_peers) * scale + 1e-6
+        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])
+        assignment = load_balance_peers(vector_size, throughputs, min_size)
+        assert np.sum(assignment) == vector_size
+        assert np.min(assignment) >= 0