Эх сурвалжийг харах

Support client-only participants in AllReduceProtocol (#176)

Resolves #147
foksly 4 жил өмнө
parent
commit
edf9327e45

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.server import *
 from hivemind.utils import *
 from hivemind.utils import *
 
 
-__version__ = '0.9.5'
+__version__ = '0.9.6'

+ 24 - 15
hivemind/client/averaging/__init__.py

@@ -7,6 +7,7 @@ import contextlib
 import ctypes
 import ctypes
 import multiprocessing as mp
 import multiprocessing as mp
 import threading
 import threading
+import uuid
 import weakref
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from concurrent.futures.thread import ThreadPoolExecutor
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
@@ -95,15 +96,13 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
             "throughput must be a non-negative float32"
             "throughput must be a non-negative float32"
-        if not listen:
-            raise NotImplementedError("Client-only averaging is not implemented yet.")
         if not is_power_of_two(target_group_size):
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
 
 
         super().__init__()
         super().__init__()
         self.dht = dht
         self.dht = dht
-        self.listen_on, self.receiver_threads, self.kwargs = listen_on, receiver_threads, kwargs
+        self.listen, self.listen_on, self.receiver_threads, self.kwargs = listen, listen_on, receiver_threads, kwargs
         self.channel_options = channel_options
         self.channel_options = channel_options
         self.daemon = daemon
         self.daemon = daemon
 
 
@@ -125,6 +124,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
         self._averager_endpoint: Optional[Endpoint] = None
         self._averager_endpoint: Optional[Endpoint] = None
+        if not self.listen:
+            self._averager_endpoint = f'client::{uuid.uuid4()}'
+
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
         background_fetcher = threading.Thread(
@@ -139,9 +141,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         return self._port.value if self._port.value != 0 else None
         return self._port.value if self._port.value != 0 else None
 
 
     @property
     @property
-    def endpoint(self) -> Endpoint:
-        assert self.port is not None, "Averager is not running yet"
-        if self._averager_endpoint is None:
+    def endpoint(self) -> Optional[Endpoint]:
+        if self.listen and self._averager_endpoint is None:
+            assert self.port is not None, "Averager is not running yet"
             self._averager_endpoint = f"{self.dht.get_visible_address()}:{self.port}"
             self._averager_endpoint = f"{self.dht.get_visible_address()}:{self.port}"
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
         return self._averager_endpoint
         return self._averager_endpoint
@@ -157,18 +159,25 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
 
         async def _run():
         async def _run():
             grpc.aio.init_grpc_aio()
             grpc.aio.init_grpc_aio()
-            server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-            averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
-            found_port = server.add_insecure_port(self.listen_on)
-            assert found_port != 0, f"Failed to listen to {self.listen_on}"
-            self._port.value = found_port
+
+            if self.listen:
+                server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
+                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
+                found_port = server.add_insecure_port(self.listen_on)
+                assert found_port != 0, f"Failed to listen to {self.listen_on}"
+                self._port.value = found_port
+                await server.start()
+            else:
+                logger.info(f"The averager running in an experimental client mode, please report any bugs.")
+
             self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
             self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
-                                            return_deltas=True)  # note: we need deltas to make allreduce lock-free
+                                            client_mode=not self.listen, return_deltas=True)
+            if self.listen:
+                asyncio.create_task(self._declare_for_download_periodically())
+
             self._pending_group_assembled = asyncio.Event()
             self._pending_group_assembled = asyncio.Event()
             self._pending_group_assembled.set()
             self._pending_group_assembled.set()
-            await server.start()
             self.ready.set()
             self.ready.set()
-            asyncio.create_task(self._declare_for_download_periodically())
 
 
             while True:
             while True:
                 method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
                 method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
@@ -240,7 +249,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                 gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                 future.set_result(gathered_data_by_peer)
                 future.set_result(gathered_data_by_peer)
 
 
-            except (AllreduceException, MatchmakingException, asyncio.exceptions.InvalidStateError,
+            except (AllreduceException, MatchmakingException, asyncio.InvalidStateError,
                     grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
                     grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
                 time_elapsed = get_dht_time() - start_time
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):

+ 7 - 2
hivemind/client/averaging/allreduce.py

@@ -30,6 +30,7 @@ class AllReduceProtocol:
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         self.group_id, self.endpoint = group_id, endpoint
         self.group_id, self.endpoint = group_id, endpoint
         self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
         self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
+        self.client_mode_endpoints = {endpoint for endpoint, size in zip(self.ordered_group_endpoints, part_sizes) if size == 0}
         self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
         self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
         self.return_deltas = return_deltas
         self.return_deltas = return_deltas
@@ -39,6 +40,8 @@ class AllReduceProtocol:
         self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
         self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
         self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
         self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
         self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
         self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
+        for endpoint in self.client_mode_endpoints:
+            self.averaged_tensor_parts[endpoint] = torch.tensor([])
 
 
     def __repr__(self):
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
@@ -59,6 +62,7 @@ class AllReduceProtocol:
         assert not self.future.done(), f"already finished allreduce: {self.future}"
         assert not self.future.done(), f"already finished allreduce: {self.future}"
         assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
         assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
         assert source not in self.accumulated_from, "duplicate source, already received that part"
         assert source not in self.accumulated_from, "duplicate source, already received that part"
+        assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
         logger.debug(f"{self} - accumulating tensor part from {source}")
         logger.debug(f"{self} - accumulating tensor part from {source}")
 
 
         self.accumulator.add_(remote_part)
         self.accumulator.add_(remote_part)
@@ -172,8 +176,9 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
         send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
         send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
         """
         """
         try:
         try:
-            await asyncio.gather(self, *(self._communicate_with_peer(peer, part)
-                                         for peer, part in self.local_tensor_parts.items() if peer != self.endpoint))
+            await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
+                                         for i, peer in enumerate(self.ordered_group_endpoints)
+                                         if peer != self.endpoint and self.part_sizes[i] > 0))
             return await self
             return await self
         except BaseException as e:
         except BaseException as e:
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR

+ 2 - 1
hivemind/client/averaging/load_balancing.py

@@ -11,7 +11,8 @@ def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_
     """
     """
     Find an optimal partitioning of weights for butterfly all-reduce given peer throughputs.
     Find an optimal partitioning of weights for butterfly all-reduce given peer throughputs.
     :param vector_size: total size of the averaged vector (in elements, not bytes)
     :param vector_size: total size of the averaged vector (in elements, not bytes)
-    :param throughputs: 1d array of non-negative throughputs for each peer, typically min(upload speed, download speed)
+    :param throughputs: 1d array of non-negative throughputs for each peer capable of averaging
      zeros stand for client-only participants; None represents "not specified" (resolved as mean of other peers)
     :param min_size: peers that can aggregate less than this many elements will be assigned nothing
     :param min_size: peers that can aggregate less than this many elements will be assigned nothing
     :returns: an integer array where i-th element is the number of weights assigned to i-th peer
     :returns: an integer array where i-th element is the number of weights assigned to i-th peer
     """
     """

+ 29 - 21
hivemind/client/averaging/matchmaking.py

@@ -17,7 +17,7 @@ from hivemind.client.averaging.allreduce import AllReduceRunner
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration, get_dht_time
 from hivemind.dht import DHT, DHTID, DHTExpiration, get_dht_time
-from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, timed_storage, TimedStorage
+from hivemind.utils import get_logger, Endpoint, TensorDescriptor, timed_storage, TimedStorage
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.utils.grpc import ChannelCache
 from hivemind.utils.grpc import ChannelCache
 
 
@@ -29,19 +29,19 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
     f"""
     f"""
     An internal class that is used to form groups of averages for running allreduce
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
     See DecentralizedAverager docstring for the detailed description of all parameters
-    
+
     :note: on implementation: the current matchmaker protocol can encounter one type of (temporary) deadlock;
     :note: on implementation: the current matchmaker protocol can encounter one type of (temporary) deadlock;
       This deadlock occurs when averager A requests averager B at the same time as averager B requests averager A.
       This deadlock occurs when averager A requests averager B at the same time as averager B requests averager A.
       In that case, neither averager can process the other one's request because it is awaiting lock_request_join_group.
       In that case, neither averager can process the other one's request because it is awaiting lock_request_join_group.
-      This deadlock only happens if averagers have outdated information on expirations (due to network delays). 
+      This deadlock only happens if averagers have outdated information on expirations (due to network delays).
       While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
       While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
       Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
       Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
     """
     """
 
 
     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *,
     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *,
-                 prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float, throughput: Optional[float] = None,
-                 min_vector_size: int, **allreduce_kwargs):
+                 prefix: str, target_group_size: int, min_group_size: int, min_vector_size: int,
+                 request_timeout: float, client_mode: bool, initial_group_bits: Optional[str] = None,
+                 averaging_expiration: float = 15, throughput: Optional[float] = None, **allreduce_kwargs):
         assert '.' not in prefix, "group prefix must be a string without ."
         assert '.' not in prefix, "group prefix must be a string without ."
         if request_timeout is None or request_timeout >= averaging_expiration:
         if request_timeout is None or request_timeout >= averaging_expiration:
             logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
             logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
@@ -52,6 +52,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
         self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.client_mode = client_mode
         self.throughput, self.min_vector_size, self.allreduce_kwargs = throughput, min_vector_size, allreduce_kwargs
         self.throughput, self.min_vector_size, self.allreduce_kwargs = throughput, min_vector_size, allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
         self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)
         self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)
@@ -80,7 +81,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 lfg_status += f" leading {len(self.current_followers)} followers,"
                 lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
-               f" current key = {self.group_key_manager.current_key})"
+               f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
 
 
     async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
     async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
                              ) -> Optional[AllReduceRunner]:
                              ) -> Optional[AllReduceRunner]:
@@ -124,7 +125,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
-        async with self.potential_leaders.begin_search(self.group_key_manager, timeout):
+        async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
             while True:
             while True:
                 try:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
@@ -166,7 +167,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                 call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                     endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
                     endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
                     throughput=self.throughput if self.throughput is not None else -1.0,
                     throughput=self.throughput if self.throughput is not None else -1.0,
-                    gather=self.data_for_gather))
+                    client_mode=self.client_mode, gather=self.data_for_gather))
                 message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
                 message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
 
 
                 if message.code == averaging_pb2.ACCEPTED:
                 if message.code == averaging_pb2.ACCEPTED:
@@ -276,7 +277,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
         if request.ListFields() == 3 and not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
         if request.ListFields() == 3 and not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
                 or not isinstance(request.expiration, DHTExpiration) or not isfinite(request.expiration) \
                 or not isinstance(request.expiration, DHTExpiration) or not isfinite(request.expiration) \
-                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0:
+                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0 or self.client_mode:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
 
         elif request.schema_hash != self.schema_hash:
         elif request.schema_hash != self.schema_hash:
@@ -297,24 +298,26 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
     async def leader_assemble_group(self) -> AllReduceRunner:
     async def leader_assemble_group(self) -> AllReduceRunner:
         """ Form up all current followers into a group and prepare to _run_allreduce """
         """ Form up all current followers into a group and prepare to _run_allreduce """
-        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert not self.assembled_group.done()
         assert not self.assembled_group.done()
         group_id = DHTID.generate().to_bytes()
         group_id = DHTID.generate().to_bytes()
         ordered_group_endpoints = list(self.current_followers)
         ordered_group_endpoints = list(self.current_followers)
         ordered_group_endpoints.append(self.endpoint)
         ordered_group_endpoints.append(self.endpoint)
         random.shuffle(ordered_group_endpoints)
         random.shuffle(ordered_group_endpoints)
 
 
-        throughputs, gathered = [], []
+        averager_throughputs, gathered = [], []
         for endpoint in ordered_group_endpoints:
         for endpoint in ordered_group_endpoints:
             if endpoint == self.endpoint:
             if endpoint == self.endpoint:
-                throughputs.append(self.throughput)
+                averager_throughputs.append(self.throughput)
                 gathered.append(self.data_for_gather)
                 gathered.append(self.data_for_gather)
             else:
             else:
                 follower_info = self.current_followers[endpoint]
                 follower_info = self.current_followers[endpoint]
-                throughputs.append(follower_info.throughput if follower_info.throughput >= 0 else None)
+                throughput = follower_info.throughput if follower_info.throughput >= 0 else None
+                averager_throughput = throughput if not follower_info.client_mode else 0.0
+                averager_throughputs.append(averager_throughput)
                 gathered.append(follower_info.gather if follower_info.gather else None)
                 gathered.append(follower_info.gather if follower_info.gather else None)
 
 
-        part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
+        part_sizes = load_balance_peers(self.total_size, averager_throughputs, self.min_vector_size)
         group_key_seed = random.randint(- 2 ** 31, 2 ** 31 - 1)
         group_key_seed = random.randint(- 2 ** 31, 2 ** 31 - 1)
 
 
         logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
         logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
@@ -331,13 +334,15 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         assert not self.assembled_group.done()
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
 
 
-        group_id, ordered_group_endpoints, part_sizes = msg.group_id, msg.ordered_group_endpoints, msg.part_sizes
+        group_id, ordered_group_endpoints, part_sizes = msg.group_id, tuple(msg.ordered_group_endpoints), msg.part_sizes
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
         assert len(ordered_group_endpoints) == len(part_sizes) == len(msg.gathered)
         assert len(ordered_group_endpoints) == len(part_sizes) == len(msg.gathered)
+        my_part_size = part_sizes[ordered_group_endpoints.index(self.endpoint)]
+        assert my_part_size == 0 or not self.client_mode, "Averager with client_mode=True cannot accept incoming data."
 
 
         logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-                                          ordered_group_endpoints=tuple(ordered_group_endpoints),
+                                          ordered_group_endpoints=ordered_group_endpoints,
                                           part_sizes=tuple(part_sizes), gathered=msg.gathered,
                                           part_sizes=tuple(part_sizes), gathered=msg.gathered,
                                           group_key_seed=int(msg.group_key_seed), **self.allreduce_kwargs)
                                           group_key_seed=int(msg.group_key_seed), **self.allreduce_kwargs)
         await self.group_key_manager.update_key_on_group_assembled(allreduce_group)
         await self.group_key_manager.update_key_on_group_assembled(allreduce_group)
@@ -346,7 +351,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
     async def leader_disband_group(self):
     async def leader_disband_group(self):
         """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
         """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
-        assert self.lock_request_join_group.locked()
+        assert self.lock_request_join_group.locked() and not self.client_mode
         self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out
         self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out
 
 
 
 
@@ -366,19 +371,22 @@ class PotentialLeaders:
         self.search_end_time = float('inf')
         self.search_end_time = float('inf')
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
-    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float]):
+    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
         async with self.lock_search:
         async with self.lock_search:
             self.running.set()
             self.running.set()
             self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
             self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
-            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+            if declare:
+                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+
             try:
             try:
                 yield self
                 yield self
             finally:
             finally:
                 if not update_queue_task.done():
                 if not update_queue_task.done():
                     update_queue_task.cancel()
                     update_queue_task.cancel()
-                if not declare_averager_task.done():
+                if declare and not declare_averager_task.done():
                     declare_averager_task.cancel()
                     declare_averager_task.cancel()
+
                 for field in (self.past_attempts, self.leader_queue, self.running,
                 for field in (self.past_attempts, self.leader_queue, self.running,
                               self.update_finished, self.update_triggered, self.declared_expiration):
                               self.update_finished, self.update_triggered, self.declared_expiration):
                     field.clear()
                     field.clear()

+ 2 - 1
hivemind/proto/averaging.proto

@@ -35,7 +35,8 @@ message JoinRequest {
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
-  float throughput = 5;         // Follower has this bandwidth for averaging (0 = default, negative = client only)
+  float throughput = 5;         // Follower has this bandwidth for averaging (-1 = default)
+  bool client_mode = 6;         // if True, the incoming averager is a client with no capacity for averaging
 }
 }
 
 
 message MessageFromLeader {
 message MessageFromLeader {

+ 12 - 7
tests/test_averaging.py

@@ -42,9 +42,14 @@ async def test_key_manager():
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
-def test_allreduce_once():
+@pytest.mark.parametrize("n_client_mode_peers", [0, 2])
+def test_allreduce_once(n_client_mode_peers):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
 
 
+    n_peers = 4
+    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
+    random.shuffle(should_listen)
+
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
@@ -53,9 +58,9 @@ def test_allreduce_once():
     reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
     reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
 
 
     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', listen_on='127.0.0.1:*',
+                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
                                                 start=True)
                                                 start=True)
-                 for tensors in [tensors1, tensors2, tensors3, tensors4]]
+                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
 
 
     futures = []
     futures = []
     for averager in averagers:
     for averager in averagers:
@@ -120,7 +125,7 @@ def test_allreduce_grid():
 @pytest.mark.forked
 @pytest.mark.forked
 def test_allgather():
 def test_allgather():
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
-    averagers = [hivemind.DecentralizedAverager(torch.ones(1), dht=dht, target_group_size=4, averaging_expiration=15,
+    averagers = [hivemind.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4, averaging_expiration=15,
                                                 prefix='mygroup', initial_group_bits='000', listen_on='127.0.0.1:*',
                                                 prefix='mygroup', initial_group_bits='000', listen_on='127.0.0.1:*',
                                                 start=True)
                                                 start=True)
                  for _ in range(8)]
                  for _ in range(8)]
@@ -150,7 +155,7 @@ def test_allgather():
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_allreduce_protocol():
 async def test_allreduce_protocol():
     """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
     """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
-    peers = "alice", "bob", "carol"
+    peers = "alice", "bob", "carol", "colab"
 
 
     tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
     tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
                        for i, peer in enumerate(peers)}
                        for i, peer in enumerate(peers)}
@@ -158,7 +163,7 @@ async def test_allreduce_protocol():
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
     allreduce_protocols = [AllReduceProtocol(
     allreduce_protocols = [AllReduceProtocol(
         group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
         group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
-        ordered_group_endpoints=peers, part_sizes=(150, 200, 67))
+        ordered_group_endpoints=peers, part_sizes=(150, 200, 67, 0))
         for peer in peers]
         for peer in peers]
 
 
     async def _accumulate(sender: Endpoint, recipient: Endpoint):
     async def _accumulate(sender: Endpoint, recipient: Endpoint):
@@ -169,7 +174,7 @@ async def test_allreduce_protocol():
         sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)
         sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)
 
 
     await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
     await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
-                        if sender != recipient})
+                        if sender != recipient and recipient != "colab"})
 
 
     reference_tensors = [
     reference_tensors = [
         sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
         sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)