Browse Source

Support client-only participants in AllReduceProtocol (#176)

Resolves #147
foksly 4 years ago
parent
commit
edf9327e45

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.9.5'
+__version__ = '0.9.6'

+ 24 - 15
hivemind/client/averaging/__init__.py

@@ -7,6 +7,7 @@ import contextlib
 import ctypes
 import multiprocessing as mp
 import threading
+import uuid
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
@@ -95,15 +96,13 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
             "throughput must be a non-negative float32"
-        if not listen:
-            raise NotImplementedError("Client-only averaging is not implemented yet.")
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
 
         super().__init__()
         self.dht = dht
-        self.listen_on, self.receiver_threads, self.kwargs = listen_on, receiver_threads, kwargs
+        self.listen, self.listen_on, self.receiver_threads, self.kwargs = listen, listen_on, receiver_threads, kwargs
         self.channel_options = channel_options
         self.daemon = daemon
 
@@ -125,6 +124,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
         self._averager_endpoint: Optional[Endpoint] = None
+        if not self.listen:
+            self._averager_endpoint = f'client::{uuid.uuid4()}'
+
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
@@ -139,9 +141,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         return self._port.value if self._port.value != 0 else None
 
     @property
-    def endpoint(self) -> Endpoint:
-        assert self.port is not None, "Averager is not running yet"
-        if self._averager_endpoint is None:
+    def endpoint(self) -> Optional[Endpoint]:
+        if self.listen and self._averager_endpoint is None:
+            assert self.port is not None, "Averager is not running yet"
             self._averager_endpoint = f"{self.dht.get_visible_address()}:{self.port}"
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
         return self._averager_endpoint
@@ -157,18 +159,25 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         async def _run():
             grpc.aio.init_grpc_aio()
-            server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-            averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
-            found_port = server.add_insecure_port(self.listen_on)
-            assert found_port != 0, f"Failed to listen to {self.listen_on}"
-            self._port.value = found_port
+
+            if self.listen:
+                server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
+                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
+                found_port = server.add_insecure_port(self.listen_on)
+                assert found_port != 0, f"Failed to listen to {self.listen_on}"
+                self._port.value = found_port
+                await server.start()
+            else:
+                logger.info(f"The averager running in an experimental client mode, please report any bugs.")
+
             self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
-                                            return_deltas=True)  # note: we need deltas to make allreduce lock-free
+                                            client_mode=not self.listen, return_deltas=True)
+            if self.listen:
+                asyncio.create_task(self._declare_for_download_periodically())
+
             self._pending_group_assembled = asyncio.Event()
             self._pending_group_assembled.set()
-            await server.start()
             self.ready.set()
-            asyncio.create_task(self._declare_for_download_periodically())
 
             while True:
                 method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
@@ -240,7 +249,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                 future.set_result(gathered_data_by_peer)
 
-            except (AllreduceException, MatchmakingException, asyncio.exceptions.InvalidStateError,
+            except (AllreduceException, MatchmakingException, asyncio.InvalidStateError,
                     grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
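
A minimal usage sketch of the new client-only mode (hypothetical tensor values; assumes a reachable DHT node, mirroring the tests below):

import torch
import hivemind

dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')

# a regular peer: starts a gRPC server and accepts incoming averaging requests
server_peer = hivemind.DecentralizedAverager(
    [torch.zeros(16)], dht=dht, prefix='mygroup', target_group_size=2,
    listen=True, listen_on='127.0.0.1:*', start=True)

# a client-only peer: no server is started; its endpoint becomes 'client::<uuid>'
client_peer = hivemind.DecentralizedAverager(
    [torch.zeros(16)], dht=dht, prefix='mygroup', target_group_size=2,
    listen=False, start=True)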

+ 7 - 2
hivemind/client/averaging/allreduce.py

@@ -30,6 +30,7 @@ class AllReduceProtocol:
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         self.group_id, self.endpoint = group_id, endpoint
         self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
+        self.client_mode_endpoints = {endpoint for endpoint, size in zip(self.ordered_group_endpoints, part_sizes) if size == 0}
         self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
         self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
         self.return_deltas = return_deltas
@@ -39,6 +40,8 @@ class AllReduceProtocol:
         self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
         self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
         self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
+        for endpoint in self.client_mode_endpoints:
+            self.averaged_tensor_parts[endpoint] = torch.tensor([])
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
@@ -59,6 +62,7 @@ class AllReduceProtocol:
         assert not self.future.done(), f"already finished allreduce: {self.future}"
         assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
         assert source not in self.accumulated_from, "duplicate source, already received that part"
+        assert self.endpoint not in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
         logger.debug(f"{self} - accumulating tensor part from {source}")
 
         self.accumulator.add_(remote_part)
@@ -172,8 +176,9 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
         send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
         """
         try:
-            await asyncio.gather(self, *(self._communicate_with_peer(peer, part)
-                                         for peer, part in self.local_tensor_parts.items() if peer != self.endpoint))
+            await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
+                                         for i, peer in enumerate(self.ordered_group_endpoints)
+                                         if peer != self.endpoint and self.part_sizes[i] > 0))
             return await self
         except BaseException as e:
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
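
The zero-size convention is easy to check in isolation; a sketch with made-up endpoints, mirroring the comprehension added to AllReduceProtocol.__init__ above:

ordered_group_endpoints = ('alice', 'bob', 'client::1234')
part_sizes = (150, 200, 0)

# peers that were assigned a zero-size part participate as clients only
client_mode_endpoints = {endpoint for endpoint, size
                         in zip(ordered_group_endpoints, part_sizes) if size == 0}
assert client_mode_endpoints == {'client::1234'}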

+ 2 - 1
hivemind/client/averaging/load_balancing.py

@@ -11,7 +11,8 @@ def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_
     """
     Find an optimal partitioning of weights for butterfly all-reduce given peer throughputs.
     :param vector_size: total size of the averaged vector (in elements, not bytes)
-    :param throughputs: 1d array of non-negative throughputs for each peer, typically min(upload speed, download speed)
+    :param throughputs: 1d array of non-negative throughputs for each peer capable of averaging;
+      zeros stand for client-only participants, None means "not specified" (resolved as the mean of other peers)
     :param min_size: peers that can aggregate less than this many elements will be assigned nothing
     :returns: an integer array where i-th element is the number of weights assigned to i-th peer
     """

+ 29 - 21
hivemind/client/averaging/matchmaking.py

@@ -17,7 +17,7 @@ from hivemind.client.averaging.allreduce import AllReduceRunner
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration, get_dht_time
-from hivemind.utils import get_logger, Endpoint, TensorDescriptor, MSGPackSerializer, timed_storage, TimedStorage
+from hivemind.utils import get_logger, Endpoint, TensorDescriptor, timed_storage, TimedStorage
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.utils.grpc import ChannelCache
 
@@ -29,19 +29,19 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
     f"""
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
-    
+
     :note: on implementation: the current matchmaker protocol can encounter one type of (temporary) deadlock;
       This deadlock occurs when averager A requests averager B at the same time as averager B requests averager A.
       In that case, neither averager can process the other one's request because it is awaiting lock_request_join_group.
-      This deadlock only happens if averagers have outdated information on expirations (due to network delays). 
+      This deadlock only happens if averagers have outdated information on expirations (due to network delays).
       While A->B->A deadlock is easy to fix, it gets much harder with more peers (e.g. A -> B -> C -> D -> A).
       Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
     """
 
     def __init__(self, endpoint: Endpoint, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *,
-                 prefix: str, target_group_size: int, min_group_size: int, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float, throughput: Optional[float] = None,
-                 min_vector_size: int, **allreduce_kwargs):
+                 prefix: str, target_group_size: int, min_group_size: int, min_vector_size: int,
+                 request_timeout: float, client_mode: bool, initial_group_bits: Optional[str] = None,
+                 averaging_expiration: float = 15, throughput: Optional[float] = None, **allreduce_kwargs):
         assert '.' not in prefix, "group prefix must be a string without ."
         if request_timeout is None or request_timeout >= averaging_expiration:
             logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
@@ -52,6 +52,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.client_mode = client_mode
         self.throughput, self.min_vector_size, self.allreduce_kwargs = throughput, min_vector_size, allreduce_kwargs
         self.schema_hash = compute_schema_hash(self.averaged_tensors)
         self.total_size = sum(tensor.numel() for tensor in self.averaged_tensors)
@@ -80,7 +81,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
-               f" current key = {self.group_key_manager.current_key})"
+               f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
 
     async def look_for_group(self, *, data_for_gather: bytes = b'', timeout: Optional[float] = None
                              ) -> Optional[AllReduceRunner]:
@@ -124,7 +125,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> AllReduceRunner:
         """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
-        async with self.potential_leaders.begin_search(self.group_key_manager, timeout):
+        async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
@@ -166,7 +167,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                     endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
                     throughput=self.throughput if self.throughput is not None else -1.0,
-                    gather=self.data_for_gather))
+                    client_mode=self.client_mode, gather=self.data_for_gather))
                 message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
@@ -276,7 +277,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
         if request.ListFields() == 3 and not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
                 or not isinstance(request.expiration, DHTExpiration) or not isfinite(request.expiration) \
-                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0:
+                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0 or self.client_mode:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
         elif request.schema_hash != self.schema_hash:
@@ -297,24 +298,26 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
     async def leader_assemble_group(self) -> AllReduceRunner:
         """ Form up all current followers into a group and prepare to _run_allreduce """
-        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
+        assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert not self.assembled_group.done()
         group_id = DHTID.generate().to_bytes()
         ordered_group_endpoints = list(self.current_followers)
         ordered_group_endpoints.append(self.endpoint)
         random.shuffle(ordered_group_endpoints)
 
-        throughputs, gathered = [], []
+        averager_throughputs, gathered = [], []
         for endpoint in ordered_group_endpoints:
             if endpoint == self.endpoint:
-                throughputs.append(self.throughput)
+                averager_throughputs.append(self.throughput)
                 gathered.append(self.data_for_gather)
             else:
                 follower_info = self.current_followers[endpoint]
-                throughputs.append(follower_info.throughput if follower_info.throughput >= 0 else None)
+                throughput = follower_info.throughput if follower_info.throughput >= 0 else None
+                averager_throughput = throughput if not follower_info.client_mode else 0.0
+                averager_throughputs.append(averager_throughput)
                 gathered.append(follower_info.gather if follower_info.gather else None)
 
-        part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
+        part_sizes = load_balance_peers(self.total_size, averager_throughputs, self.min_vector_size)
         group_key_seed = random.randint(- 2 ** 31, 2 ** 31 - 1)
 
         logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")
@@ -331,13 +334,15 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
 
-        group_id, ordered_group_endpoints, part_sizes = msg.group_id, msg.ordered_group_endpoints, msg.part_sizes
+        group_id, ordered_group_endpoints, part_sizes = msg.group_id, tuple(msg.ordered_group_endpoints), msg.part_sizes
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
         assert len(ordered_group_endpoints) == len(part_sizes) == len(msg.gathered)
+        my_part_size = part_sizes[ordered_group_endpoints.index(self.endpoint)]
+        assert my_part_size == 0 or not self.client_mode, "Averager with client_mode=True cannot accept incoming data."
 
         logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")
         allreduce_group = AllReduceRunner(group_id=group_id, tensors=self.averaged_tensors, endpoint=self.endpoint,
-                                          ordered_group_endpoints=tuple(ordered_group_endpoints),
+                                          ordered_group_endpoints=ordered_group_endpoints,
                                           part_sizes=tuple(part_sizes), gathered=msg.gathered,
                                           group_key_seed=int(msg.group_key_seed), **self.allreduce_kwargs)
         await self.group_key_manager.update_key_on_group_assembled(allreduce_group)
@@ -346,7 +351,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
     async def leader_disband_group(self):
         """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
-        assert self.lock_request_join_group.locked()
+        assert self.lock_request_join_group.locked() and not self.client_mode
         self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out
 
 
@@ -366,19 +371,22 @@ class PotentialLeaders:
         self.search_end_time = float('inf')
 
     @contextlib.asynccontextmanager
-    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float]):
+    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
         async with self.lock_search:
             self.running.set()
             self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
-            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+            if declare:
+                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+
             try:
                 yield self
             finally:
                 if not update_queue_task.done():
                     update_queue_task.cancel()
-                if not declare_averager_task.done():
+                if declare and not declare_averager_task.done():
                     declare_averager_task.cancel()
+
                 for field in (self.past_attempts, self.leader_queue, self.running,
                               self.update_finished, self.update_triggered, self.declared_expiration):
                     field.clear()
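
The leader-side throughput resolution above can be restated as a standalone rule (a hypothetical helper, not part of the codebase; -1.0 is the wire encoding for "not specified", see averaging.proto below):

from typing import Optional

def resolve_averager_throughput(reported: float, client_mode: bool) -> Optional[float]:
    """Map a follower's JoinRequest fields to a load_balance_peers input."""
    if client_mode:
        return 0.0  # client-only peers must end up with a zero-size part
    return reported if reported >= 0 else None  # None = resolved as mean of other peers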

+ 2 - 1
hivemind/proto/averaging.proto

@@ -35,7 +35,8 @@ message JoinRequest {
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
-  float throughput = 5;         // Follower has this bandwidth for averaging (0 = default, negative = client only)
+  float throughput = 5;         // Follower has this bandwidth for averaging (-1 = default)
+  bool client_mode = 6;         // if True, the incoming averager is a client with no capacity for averaging
 }
 
 message MessageFromLeader {
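
For reference, a sketch of filling the extended request from Python (assumes regenerated averaging_pb2 bindings; all values are placeholders):

from hivemind.dht import get_dht_time
from hivemind.proto import averaging_pb2

request = averaging_pb2.JoinRequest(
    endpoint='client::1234', schema_hash=b'\x00' * 32,
    expiration=get_dht_time() + 15,
    throughput=-1.0,    # -1 = throughput not specified
    client_mode=True,   # this peer cannot accept incoming tensor parts
    gather=b'')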

+ 12 - 7
tests/test_averaging.py

@@ -42,9 +42,14 @@ async def test_key_manager():
 
 
 @pytest.mark.forked
-def test_allreduce_once():
+@pytest.mark.parametrize("n_client_mode_peers", [0, 2])
+def test_allreduce_once(n_client_mode_peers):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
 
+    n_peers = 4
+    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
+    random.shuffle(should_listen)
+
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
@@ -53,9 +58,9 @@ def test_allreduce_once():
     reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
 
     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', listen_on='127.0.0.1:*',
+                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
                                                 start=True)
-                 for tensors in [tensors1, tensors2, tensors3, tensors4]]
+                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
 
     futures = []
     for averager in averagers:
@@ -120,7 +125,7 @@ def test_allreduce_grid():
 @pytest.mark.forked
 def test_allgather():
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
-    averagers = [hivemind.DecentralizedAverager(torch.ones(1), dht=dht, target_group_size=4, averaging_expiration=15,
+    averagers = [hivemind.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4, averaging_expiration=15,
                                                 prefix='mygroup', initial_group_bits='000', listen_on='127.0.0.1:*',
                                                 start=True)
                  for _ in range(8)]
@@ -150,7 +155,7 @@ def test_allgather():
 @pytest.mark.asyncio
 async def test_allreduce_protocol():
     """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
-    peers = "alice", "bob", "carol"
+    peers = "alice", "bob", "carol", "colab"
 
     tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
                        for i, peer in enumerate(peers)}
@@ -158,7 +163,7 @@ async def test_allreduce_protocol():
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
     allreduce_protocols = [AllReduceProtocol(
         group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
-        ordered_group_endpoints=peers, part_sizes=(150, 200, 67))
+        ordered_group_endpoints=peers, part_sizes=(150, 200, 67, 0))
         for peer in peers]
 
     async def _accumulate(sender: Endpoint, recipient: Endpoint):
@@ -169,7 +174,7 @@ async def test_allreduce_protocol():
         sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)
 
     await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
-                        if sender != recipient})
+                        if sender != recipient and recipient != "colab"})
 
     reference_tensors = [
         sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)