
Implement simplified all-reduce for asymmetric TCP connections (#385)

* add an option to run sequential all-reduce, wherein a runner receives results only after it finishes sending tensors
* ensure that calling a stream RPC begins sending inputs right away, instead of only when awaiting the first output (see the sketch after this list)
* verified performance improvements with colab + local machine
* simplified _make_rpc_caller
* CollaborativeOptimizer will now save local step separately in state dict
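
With this change, every stream-returning RPC caller becomes a coroutine: awaiting it opens the stream and starts writing inputs immediately, rather than deferring all work until the first output is pulled. A minimal before/after sketch of the new calling convention (it uses rpc_download_state from the averager.py diff below; the surrounding setup is assumed):

    # before: callers returned a lazy async generator, so nothing was sent
    # until the first `async for` iteration pulled on the stream
    #   stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())

    # after: awaiting the caller opens the stream and begins sending right away
    stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
    async for message in stream:
        ...  # responses arrive while remaining inputs are still being written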

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Michael Diskin <yhn112@users.noreply.github.com>
justheuristic committed 3 years ago
Commit 09985d843b

+ 43 - 7
hivemind/averaging/allreduce.py

@@ -9,7 +9,15 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, afirst, amap_in_executor, anext, as_aiter
+from hivemind.utils.asyncio import (
+    achain,
+    aenumerate,
+    afirst,
+    amap_in_executor,
+    anext,
+    as_aiter,
+    attach_event_on_finished,
+)
 
 # flavour types
 GroupID = bytes
@@ -44,7 +52,10 @@ class AllReduceRunner(ServicerBase):
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param gathered: additional user-defined data collected from this group
-    :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
+      non-averaging peers receive results only after they finish sending, which helps them avoid
+      throughput issues in case of asymmetric high-latency connections (e.g. ACK compression).
     """
 
     def __init__(
@@ -115,6 +126,9 @@ class AllReduceRunner(ServicerBase):
     def _get_peer_stub(self, peer: PeerID) -> StubBase:
         return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)
 
+    def should_delay_results(self, peer_id: PeerID) -> bool:
+        return self.peer_fractions[self.ordered_peer_ids.index(peer_id)] == 0
+
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
@@ -155,7 +169,7 @@ class AllReduceRunner(ServicerBase):
 
         else:
             code = None
-            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
             async for part_index, (averaged_part_delta, msg) in aenumerate(
                 amap_in_executor(
                     lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
@@ -199,8 +213,31 @@ class AllReduceRunner(ServicerBase):
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
                 sender_index = self.sender_peer_ids.index(context.remote_id)
-                async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
-                    yield msg
+
+                if not self.should_delay_results(context.remote_id):
+                    async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
+                        yield msg
+
+                else:
+                    done_receiving = asyncio.Event()
+                    delayed_results = asyncio.Queue()
+
+                    async def _accumulate_parts():
+                        inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
+                        async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
+                            delayed_results.put_nowait(msg)
+                        delayed_results.put_nowait(None)
+
+                    accumulate_task = asyncio.create_task(_accumulate_parts())
+
+                    await done_receiving.wait()
+
+                    while True:
+                        next_result = await delayed_results.get()
+                        if next_result is None:
+                            break
+                        yield next_result
+                    await accumulate_task
 
             except Exception as e:
                 self.finalize(exception=e)
@@ -239,8 +276,7 @@ class AllReduceRunner(ServicerBase):
 
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        # Coroutines are lazy, so we take the first item to start the couroutine's execution
-        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 1 - 1
hivemind/averaging/averager.py

@@ -609,7 +609,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     logger.info(f"Downloading parameters from peer {peer}")
                     try:
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
-                        stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                        stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
                         async for message in aiter_with_timeout(stream, timeout=self.request_timeout):

+ 1 - 1
hivemind/averaging/matchmaking.py

@@ -180,7 +180,7 @@ class Matchmaking:
             async with self.lock_request_join_group:
                 leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
 
-                stream = leader_stub.rpc_join_group(
+                stream = await leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,

+ 1 - 1
hivemind/averaging/partition.py

@@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 16
+DEFAULT_PART_SIZE_BYTES = 2 ** 19
 
 
 class TensorPartContainer:
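
For scale: the default part size grows from 2 ** 16 bytes (64 KiB) to 2 ** 19 bytes (512 KiB), so the same tensor payload is split into roughly 8x fewer parts, which should cut per-part serialization and round-trip overhead at the cost of coarser streaming granularity.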

+ 10 - 0
hivemind/optim/collaborative.py

@@ -210,6 +210,16 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
+    def state_dict(self) -> dict:
+        state_dict = super().state_dict()
+        state_dict["state"]["collaborative_step"] = self.local_step
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "collaborative_step" in state_dict["state"]:
+            self.averager.local_step = state_dict["state"].pop("collaborative_step")
+        return super().load_state_dict(state_dict)
+
     def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
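
With these overrides, the collaborative step counter survives an ordinary optimizer checkpoint. A usage sketch, assuming an already-constructed CollaborativeOptimizer named opt and standard torch serialization:

    import torch

    # save: state_dict() tucks the current collaborative step into the
    # optimizer state under the "collaborative_step" key
    torch.save(opt.state_dict(), "checkpoint.pt")

    # load: load_state_dict() pops that key and restores it onto the
    # averager before delegating to the base optimizer
    opt.load_state_dict(torch.load("checkpoint.pt"))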

+ 20 - 17
hivemind/p2p/p2p_daemon.py

@@ -386,22 +386,25 @@ class P2P:
                 await P2P.send_protobuf(request, writer)
             await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
 
-        with closing(writer):
-            writing_task = asyncio.create_task(_write_to_stream())
-            try:
-                while True:
-                    try:
-                        response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
-                    except asyncio.IncompleteReadError:  # Connection is closed
-                        break
+        async def _read_from_stream() -> AsyncIterator[Message]:
+            with closing(writer):
+                try:
+                    while True:
+                        try:
+                            response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
+                        except asyncio.IncompleteReadError:  # Connection is closed
+                            break
 
-                    if err is not None:
-                        raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
-                    yield response
+                        if err is not None:
+                            raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
+                        yield response
+
+                    await writing_task
+                finally:
+                    writing_task.cancel()
 
-                await writing_task
-            finally:
-                writing_task.cancel()
+        writing_task = asyncio.create_task(_write_to_stream())
+        return _read_from_stream()
 
     async def add_protobuf_handler(
         self,
@@ -476,7 +479,7 @@ class P2P:
         if not isinstance(input, AsyncIterableABC):
             return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
 
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
+        responses = await self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
         return await asingle(responses)
 
     async def _call_unary_protobuf_handler(
@@ -490,7 +493,7 @@ class P2P:
         response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
         return output_protobuf_type.FromString(response)
 
-    def iterate_protobuf_handler(
+    async def iterate_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
@@ -498,7 +501,7 @@ class P2P:
         output_protobuf_type: Type[Message],
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else as_aiter(input)
-        return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+        return await self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
 
     def _start_listening(self) -> None:
         async def listen() -> None:
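
The important move above is splitting the old handler in two: the writing task is created eagerly in the now-awaitable caller, while reading stays inside a lazy async generator. A standalone sketch of why this ordering matters (send_all and recv_one are hypothetical stand-ins for the protobuf stream I/O):

    import asyncio

    async def open_stream(requests):
        # created here, so sending starts as soon as open_stream is awaited;
        # a bare async generator would not run until its first __anext__ call
        writing_task = asyncio.create_task(send_all(requests))

        async def read_responses():
            try:
                while True:
                    response = await recv_one()
                    if response is None:  # connection closed
                        break
                    yield response
                await writing_task  # surface sender errors after reading ends
            finally:
                writing_task.cancel()

        return read_responses()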

+ 9 - 26
hivemind/p2p/servicer.py

@@ -86,38 +86,21 @@ class ServicerBase:
     @classmethod
     def _make_rpc_caller(cls, handler: RPCHandler):
         input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
+        output_type = AsyncIterator[handler.response_type] if handler.stream_output else handler.response_type
 
         # This method will be added to a new Stub type (a subclass of StubBase)
-        if handler.stream_output:
-
-            def caller(
-                self: StubBase, input: input_type, timeout: None = None
-            ) -> AsyncIterator[handler.response_type]:
-                if timeout is not None:
-                    raise ValueError("Timeouts for handlers returning streams are not supported")
-
-                return self._p2p.iterate_protobuf_handler(
-                    self._peer,
-                    cls._get_handle_name(self._namespace, handler.method_name),
-                    input,
-                    handler.response_type,
-                )
-
-        else:
-
-            async def caller(
-                self: StubBase, input: input_type, timeout: Optional[float] = None
-            ) -> handler.response_type:
+        async def caller(self: StubBase, input: input_type, timeout: Optional[float] = None) -> output_type:
+            handle_name = cls._get_handle_name(self._namespace, handler.method_name)
+            if not handler.stream_output:
                 return await asyncio.wait_for(
-                    self._p2p.call_protobuf_handler(
-                        self._peer,
-                        cls._get_handle_name(self._namespace, handler.method_name),
-                        input,
-                        handler.response_type,
-                    ),
+                    self._p2p.call_protobuf_handler(self._peer, handle_name, input, handler.response_type),
                     timeout=timeout,
                 )
 
+            if timeout is not None:
+                raise ValueError("Timeouts for handlers returning streams are not supported")
+            return await self._p2p.iterate_protobuf_handler(self._peer, handle_name, input, handler.response_type)
+
         caller.__name__ = handler.method_name
         return caller
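
After this merge, a single caller covers both RPC shapes: unary calls honor a timeout, while stream calls must be awaited to open the stream and reject timeouts explicitly. A usage sketch against the test stubs below (rpc_square is a hypothetical unary handler; rpc_count is the stream handler from tests/test_p2p_servicer.py):

    stub = ExampleServicer.get_stub(client, server.peer_id)

    # unary: awaiting returns a single response; timeout is enforced
    response = await stub.rpc_square(test_pb2.TestRequest(number=3), timeout=5.0)

    # stream: awaiting opens the stream; passing a timeout raises ValueError
    stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
    async for item in stream:
        ...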
 

+ 9 - 0
hivemind/utils/asyncio.py

@@ -138,3 +138,12 @@ async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> Asyn
             yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
         except StopAsyncIteration:
             break
+
+
+async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Event) -> AsyncIterator[T]:
+    """Iterate over an async iterable and set an event when the iteration has stopped, failed or terminated"""
+    try:
+        async for item in iterable:
+            yield item
+    finally:
+        event.set()

+ 5 - 3
tests/test_p2p_servicer.py

@@ -68,8 +68,9 @@ async def test_unary_stream(server_client):
     await servicer.add_p2p_handlers(server)
     stub = ExampleServicer.get_stub(client, server.peer_id)
 
+    stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
     i = 0
-    async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
+    async for item in stream:
         assert item == test_pb2.TestResponse(number=i)
         i += 1
     assert i == 10
@@ -94,8 +95,9 @@ async def test_stream_stream(server_client):
         for i in range(10):
             yield test_pb2.TestRequest(number=i)
 
+    stream = await stub.rpc_powers(generate_requests())
     i = 0
-    async for item in stub.rpc_powers(generate_requests()):
+    async for item in stream:
         if i % 2 == 0:
             assert item == test_pb2.TestResponse(number=(i // 2) ** 2)
         else:
@@ -140,7 +142,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
         writer.close()
     elif cancel_reason == "close_generator":
         stub = ExampleServicer.get_stub(client, server.peer_id)
-        iter = stub.rpc_wait(test_pb2.TestRequest(number=10))
+        iter = await stub.rpc_wait(test_pb2.TestRequest(number=10))
 
         assert await anext(iter) == test_pb2.TestResponse(number=11)
         await asyncio.sleep(0.25)

+ 13 - 0
tests/test_util_modules.py

@@ -23,6 +23,7 @@ from hivemind.utils.asyncio import (
     anext,
     as_aiter,
     asingle,
+    attach_event_on_finished,
     azip,
     cancel_and_wait,
 )
@@ -490,6 +491,18 @@ async def test_asyncio_utils():
 
     assert num_steps == 2
 
+    event = asyncio.Event()
+    async for i in attach_event_on_finished(iterate_with_delays([0, 0, 0, 0, 0]), event):
+        assert not event.is_set()
+    assert event.is_set()
+
+    event = asyncio.Event()
+    sleepy_aiter = iterate_with_delays([0.1, 0.1, 0.3, 0.1, 0.1])
+    with pytest.raises(asyncio.TimeoutError):
+        async for _ in attach_event_on_finished(aiter_with_timeout(sleepy_aiter, timeout=0.2), event):
+            assert not event.is_set()
+    assert event.is_set()
+
 
 @pytest.mark.asyncio
 async def test_cancel_and_wait():