
Implement simplified all-reduce for asymmetric TCP connections (#385)

* add an option to run sequential all-reduce, wherein a runner receives results only after it finishes sending tensors (see the sketch after this message)
* ensure that calling a stream RPC begins sending inputs right away, instead of only when awaiting the first output
* verified performance improvements with colab + local machine
* simplified _make_rpc_caller
* CollaborativeOptimizer will now save local step separately in state dict

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Michael Diskin <yhn112@users.noreply.github.com>
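
A minimal standalone sketch of the sequential mode from the first bullet (toy names, not the library's API): the receiving side buffers every incoming part before it starts replying, so the reverse direction of the link stays idle while inputs are in flight.

import asyncio
from typing import AsyncIterator, List


async def sequential_aggregate(parts: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
    """Receive every input part first, only then start sending results back."""
    buffered: List[bytes] = []
    async for part in parts:
        buffered.append(part[::-1])  # placeholder "aggregation"
    for result in buffered:
        yield result


async def demo() -> None:
    async def inputs() -> AsyncIterator[bytes]:
        for chunk in (b"ab", b"cd"):
            yield chunk

    print([r async for r in sequential_aggregate(inputs())])  # [b'ba', b'dc']


asyncio.run(demo())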
justheuristic 3 years ago
parent
commit
09985d843b

+ 43 - 7
hivemind/averaging/allreduce.py

@@ -9,7 +9,15 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, afirst, amap_in_executor, anext, as_aiter
+from hivemind.utils.asyncio import (
+    achain,
+    aenumerate,
+    afirst,
+    amap_in_executor,
+    anext,
+    as_aiter,
+    attach_event_on_finished,
+)
 
 # flavour types
 GroupID = bytes
@@ -44,7 +52,10 @@ class AllReduceRunner(ServicerBase):
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param gathered: additional user-defined data collected from this group
-    :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
+      non-averaging peers receive results only after they finish sending, which helps them avoid
+      throughput issues in case of asymmetric high-latency connections (e.g. ACK compression).
     """
     """
 
 
     def __init__(
     def __init__(
@@ -115,6 +126,9 @@ class AllReduceRunner(ServicerBase):
     def _get_peer_stub(self, peer: PeerID) -> StubBase:
         return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)
 
+    def should_delay_results(self, peer_id: PeerID) -> bool:
+        return self.peer_fractions[self.ordered_peer_ids.index(peer_id)] == 0
+
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
@@ -155,7 +169,7 @@ class AllReduceRunner(ServicerBase):
 
         else:
             code = None
-            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
             async for part_index, (averaged_part_delta, msg) in aenumerate(
                 amap_in_executor(
                     lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
@@ -199,8 +213,31 @@ class AllReduceRunner(ServicerBase):
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
                 sender_index = self.sender_peer_ids.index(context.remote_id)
-                async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
-                    yield msg
+
+                if not self.should_delay_results(context.remote_id):
+                    async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
+                        yield msg
+
+                else:
+                    done_receiving = asyncio.Event()
+                    delayed_results = asyncio.Queue()
+
+                    async def _accumulate_parts():
+                        inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
+                        async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
+                            delayed_results.put_nowait(msg)
+                        delayed_results.put_nowait(None)
+
+                    accumulate_task = asyncio.create_task(_accumulate_parts())
+
+                    await done_receiving.wait()
+
+                    while True:
+                        next_result = await delayed_results.get()
+                        if next_result is None:
+                            break
+                        yield next_result
+                    await accumulate_task
 
             except Exception as e:
                 self.finalize(exception=e)
@@ -239,8 +276,7 @@ class AllReduceRunner(ServicerBase):
 
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        # Coroutines are lazy, so we take the first item to start the couroutine's execution
-        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

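Distilled from the serving branch above: a self-contained sketch of the delay-until-received pattern, with a toy doubling handler standing in for _accumulate_parts_streaming.

import asyncio
from typing import AsyncIterable, AsyncIterator, Optional


async def attach_event_on_finished(iterable: AsyncIterable[int], event: asyncio.Event) -> AsyncIterator[int]:
    # Same contract as the helper added to hivemind/utils/asyncio.py below.
    try:
        async for item in iterable:
            yield item
    finally:
        event.set()


async def delayed_stream(inputs: AsyncIterable[int]) -> AsyncIterator[int]:
    done_receiving = asyncio.Event()
    results: "asyncio.Queue[Optional[int]]" = asyncio.Queue()

    async def accumulate() -> None:
        async for item in attach_event_on_finished(inputs, done_receiving):
            results.put_nowait(item * 2)  # toy stand-in for part aggregation
        results.put_nowait(None)  # sentinel: no more results

    task = asyncio.create_task(accumulate())
    await done_receiving.wait()  # hold all replies until every input has arrived
    while (item := await results.get()) is not None:
        yield item
    await task


async def demo() -> None:
    async def inputs() -> AsyncIterator[int]:
        for i in range(3):
            yield i

    print([x async for x in delayed_stream(inputs())])  # [0, 2, 4]


asyncio.run(demo())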
+ 1 - 1
hivemind/averaging/averager.py

@@ -609,7 +609,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     logger.info(f"Downloading parameters from peer {peer}")
                     try:
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
-                        stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                        stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
                         async for message in aiter_with_timeout(stream, timeout=self.request_timeout):

+ 1 - 1
hivemind/averaging/matchmaking.py

@@ -180,7 +180,7 @@ class Matchmaking:
             async with self.lock_request_join_group:
                 leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
 
-                stream = leader_stub.rpc_join_group(
+                stream = await leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,

+ 1 - 1
hivemind/averaging/partition.py

@@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 16
+DEFAULT_PART_SIZE_BYTES = 2 ** 19
 
 
 class TensorPartContainer:

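For scale, the new default makes parts eight times larger:

assert 2 ** 16 == 65_536    # old DEFAULT_PART_SIZE_BYTES: 64 KiB per part
assert 2 ** 19 == 524_288   # new DEFAULT_PART_SIZE_BYTES: 512 KiB per part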
+ 10 - 0
hivemind/optim/collaborative.py

@@ -210,6 +210,16 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
+    def state_dict(self) -> dict:
+        state_dict = super().state_dict()
+        state_dict["state"]["collaborative_step"] = self.local_step
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "collaborative_step" in state_dict["state"]:
+            self.averager.local_step = state_dict["state"].pop("collaborative_step")
+        return super().load_state_dict(state_dict)
+
     def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters

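The upshot: checkpoints now carry the collaboration step next to the wrapped optimizer's state. The same save/restore shape with a stand-in class (not the real CollaborativeOptimizer, which needs a full averaging setup):

class TinyOptimizer:
    """Stand-in for CollaborativeOptimizer: keeps a step counter in its state dict."""

    def __init__(self) -> None:
        self.local_step = 0
        self._state = {"state": {}}

    def state_dict(self) -> dict:
        state_dict = {"state": dict(self._state["state"])}
        state_dict["state"]["collaborative_step"] = self.local_step
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        if "collaborative_step" in state_dict["state"]:
            self.local_step = state_dict["state"].pop("collaborative_step")
        self._state = state_dict


saver, loader = TinyOptimizer(), TinyOptimizer()
saver.local_step = 42
loader.load_state_dict(saver.state_dict())
assert loader.local_step == 42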
+ 20 - 17
hivemind/p2p/p2p_daemon.py

@@ -386,22 +386,25 @@ class P2P:
                 await P2P.send_protobuf(request, writer)
             await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
 
-        with closing(writer):
-            writing_task = asyncio.create_task(_write_to_stream())
-            try:
-                while True:
-                    try:
-                        response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
-                    except asyncio.IncompleteReadError:  # Connection is closed
-                        break
+        async def _read_from_stream() -> AsyncIterator[Message]:
+            with closing(writer):
+                try:
+                    while True:
+                        try:
+                            response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
+                        except asyncio.IncompleteReadError:  # Connection is closed
+                            break
 
-                    if err is not None:
-                        raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
-                    yield response
+                        if err is not None:
+                            raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
+                        yield response
+
+                    await writing_task
+                finally:
+                    writing_task.cancel()
 
-                await writing_task
-            finally:
-                writing_task.cancel()
+        writing_task = asyncio.create_task(_write_to_stream())
+        return _read_from_stream()
 
     async def add_protobuf_handler(
         self,
@@ -476,7 +479,7 @@ class P2P:
         if not isinstance(input, AsyncIterableABC):
             return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
 
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
+        responses = await self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
         return await asingle(responses)
 
     async def _call_unary_protobuf_handler(
@@ -490,7 +493,7 @@ class P2P:
         response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
         return output_protobuf_type.FromString(response)
 
-    def iterate_protobuf_handler(
+    async def iterate_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
@@ -498,7 +501,7 @@ class P2P:
         output_protobuf_type: Type[Message],
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else as_aiter(input)
-        return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+        return await self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
 
     def _start_listening(self) -> None:
         async def listen() -> None:

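The user-visible consequence of this refactor: stream calls are now coroutines that must be awaited, and the write task starts as soon as the call returns rather than on the first read. A toy model of the new convention (open_stream is a stand-in for the daemon call, not hivemind's API):

import asyncio
from typing import AsyncIterator, Optional


async def open_stream(requests: AsyncIterator[str]) -> AsyncIterator[str]:
    """Toy stream RPC: begins sending before the caller reads anything."""
    sent: "asyncio.Queue[Optional[str]]" = asyncio.Queue()

    async def write() -> None:
        async for request in requests:
            await sent.put(request)  # sending starts immediately
        await sent.put(None)

    writing_task = asyncio.create_task(write())  # eager, like _write_to_stream

    async def read() -> AsyncIterator[str]:
        try:
            while (request := await sent.get()) is not None:
                yield request.upper()  # echo back a "response"
        finally:
            writing_task.cancel()

    return read()


async def demo() -> None:
    async def requests() -> AsyncIterator[str]:
        yield "ping"

    stream = await open_stream(requests())  # new convention: await the call...
    async for response in stream:           # ...then iterate the returned stream
        print(response)                     # PING


asyncio.run(demo())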
+ 9 - 26
hivemind/p2p/servicer.py

@@ -86,38 +86,21 @@ class ServicerBase:
     @classmethod
     def _make_rpc_caller(cls, handler: RPCHandler):
         input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
+        output_type = AsyncIterator[handler.response_type] if handler.stream_output else handler.response_type
 
         # This method will be added to a new Stub type (a subclass of StubBase)
-        if handler.stream_output:
-
-            def caller(
-                self: StubBase, input: input_type, timeout: None = None
-            ) -> AsyncIterator[handler.response_type]:
-                if timeout is not None:
-                    raise ValueError("Timeouts for handlers returning streams are not supported")
-
-                return self._p2p.iterate_protobuf_handler(
-                    self._peer,
-                    cls._get_handle_name(self._namespace, handler.method_name),
-                    input,
-                    handler.response_type,
-                )
-
-        else:
-
-            async def caller(
-                self: StubBase, input: input_type, timeout: Optional[float] = None
-            ) -> handler.response_type:
+        async def caller(self: StubBase, input: input_type, timeout: Optional[float] = None) -> output_type:
+            handle_name = cls._get_handle_name(self._namespace, handler.method_name)
+            if not handler.stream_output:
                 return await asyncio.wait_for(
-                    self._p2p.call_protobuf_handler(
-                        self._peer,
-                        cls._get_handle_name(self._namespace, handler.method_name),
-                        input,
-                        handler.response_type,
-                    ),
+                    self._p2p.call_protobuf_handler(self._peer, handle_name, input, handler.response_type),
                     timeout=timeout,
                 )
 
+            if timeout is not None:
+                raise ValueError("Timeouts for handlers returning streams are not supported")
+            return await self._p2p.iterate_protobuf_handler(self._peer, handle_name, input, handler.response_type)
+
         caller.__name__ = handler.method_name
         return caller
 

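With the unified caller, timeout handling is the only remaining fork: unary calls keep asyncio.wait_for-based timeouts, while stream calls reject a timeout up front. A condensed sketch of that shape (illustrative names, not the servicer's real signature):

import asyncio
from typing import AsyncIterator, Optional


async def caller(number: int, timeout: Optional[float] = None, stream_output: bool = False):
    if not stream_output:
        # Unary path: result delivered through asyncio.wait_for.
        return await asyncio.wait_for(asyncio.sleep(0, result=number), timeout=timeout)

    if timeout is not None:
        raise ValueError("Timeouts for handlers returning streams are not supported")

    async def stream() -> AsyncIterator[int]:
        for i in range(number):
            yield i

    return stream()


async def demo() -> None:
    print(await caller(3, timeout=1.0))                             # unary: 3
    print([i async for i in await caller(3, stream_output=True)])   # stream: [0, 1, 2]


asyncio.run(demo())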
+ 9 - 0
hivemind/utils/asyncio.py

@@ -138,3 +138,12 @@ async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
             yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
         except StopAsyncIteration:
             break
+
+
+async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Event) -> AsyncIterator[T]:
+    """Iterate over an async iterable and set an event when the iteration has stopped, failed or terminated"""
+    try:
+        async for item in iterable:
+            yield item
+    finally:
+        event.set()

+ 5 - 3
tests/test_p2p_servicer.py

@@ -68,8 +68,9 @@ async def test_unary_stream(server_client):
     await servicer.add_p2p_handlers(server)
     stub = ExampleServicer.get_stub(client, server.peer_id)
 
+    stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
     i = 0
-    async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
+    async for item in stream:
         assert item == test_pb2.TestResponse(number=i)
         i += 1
     assert i == 10
@@ -94,8 +95,9 @@ async def test_stream_stream(server_client):
         for i in range(10):
             yield test_pb2.TestRequest(number=i)
 
+    stream = await stub.rpc_powers(generate_requests())
     i = 0
-    async for item in stub.rpc_powers(generate_requests()):
+    async for item in stream:
         if i % 2 == 0:
             assert item == test_pb2.TestResponse(number=(i // 2) ** 2)
         else:
@@ -140,7 +142,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
         writer.close()
     elif cancel_reason == "close_generator":
         stub = ExampleServicer.get_stub(client, server.peer_id)
-        iter = stub.rpc_wait(test_pb2.TestRequest(number=10))
+        iter = await stub.rpc_wait(test_pb2.TestRequest(number=10))
 
         assert await anext(iter) == test_pb2.TestResponse(number=11)
         await asyncio.sleep(0.25)

+ 13 - 0
tests/test_util_modules.py

@@ -23,6 +23,7 @@ from hivemind.utils.asyncio import (
     anext,
     as_aiter,
     asingle,
+    attach_event_on_finished,
     azip,
     cancel_and_wait,
 )
@@ -490,6 +491,18 @@ async def test_asyncio_utils():
 
     assert num_steps == 2
 
+    event = asyncio.Event()
+    async for i in attach_event_on_finished(iterate_with_delays([0, 0, 0, 0, 0]), event):
+        assert not event.is_set()
+    assert event.is_set()
+
+    event = asyncio.Event()
+    sleepy_aiter = iterate_with_delays([0.1, 0.1, 0.3, 0.1, 0.1])
+    with pytest.raises(asyncio.TimeoutError):
+        async for _ in attach_event_on_finished(aiter_with_timeout(sleepy_aiter, timeout=0.2), event):
+            assert not event.is_set()
+    assert event.is_set()
+
 
 @pytest.mark.asyncio
 async def test_cancel_and_wait():