Преглед изворни кода

Fix minor asyncio issues in averager (#356)

1. Make the averager not fail while reporting an error if the recipient closes the connection before receiving it (e.g. if the recipient itself has crashed).
2. Support asyncio cancellation and shutdown in _load_state_from_peers, that is, don't suppress asyncio.CancelledError or GeneratorExit.
3. Remove excess calls of __aiter__()/__anext__() (outside hivemind.utils.asyncio, we can always manage with anext(), so let's not scare the code readers with these magic methods).

Co-authored-by: justheuristic <justheuristic@gmail.com>
Alexander Borzunov пре 4 година
родитељ
комит
f97b742508

+ 2 - 2
hivemind/averaging/allreduce.py

@@ -231,8 +231,8 @@ class AllReduceRunner(ServicerBase):
 
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        # In case of reporting the error, we expect the response stream to contain exactly one item
-        await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
+        # Coroutines are lazy, so we take the first item to start the coroutine's execution
+        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 1 - 1
hivemind/averaging/averager.py

@@ -593,7 +593,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         future.set_result((metadata, tensors))
                         self.last_updated = get_dht_time()
                         return
-                    except BaseException as e:
+                    except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
 
         finally:

+ 1 - 1
hivemind/averaging/matchmaking.py

@@ -189,7 +189,7 @@ class Matchmaking:
                         gather=self.data_for_gather,
                         group_key=self.group_key_manager.current_key,
                     )
-                ).__aiter__()
+                )
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:

+ 9 - 3
hivemind/p2p/p2p_daemon.py

@@ -311,11 +311,17 @@ class P2P:
             async def _process_stream() -> None:
                 try:
                     async for response in handler(_read_stream(), context):
-                        await P2P.send_protobuf(response, writer)
+                        try:
+                            await P2P.send_protobuf(response, writer)
+                        except Exception:
+                            # The connection is unexpectedly closed by the caller or broken.
+                            # The loglevel is DEBUG since the actual error will be reported on the caller
+                            logger.debug("Exception while sending response:", exc_info=True)
+                            break
                 except Exception as e:
-                    logger.warning("Exception while processing stream and sending responses:", exc_info=True)
-                    # Sometimes `e` is a connection error, so we won't be able to report the error to the caller
+                    logger.warning("Handler failed with the exception:", exc_info=True)
                     with suppress(Exception):
+                        # Sometimes `e` is a connection error, so it is okay if we fail to report `e` to the caller
                         await P2P.send_protobuf(RPCError(message=str(e)), writer)
 
             with closing(writer):

+ 8 - 1
hivemind/utils/asyncio.py

@@ -59,7 +59,7 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]
 
 
 async def asingle(aiter: AsyncIterable[T]) -> T:
-    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`."""
+    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises ``ValueError``."""
     count = 0
     async for item in aiter:
         count += 1
@@ -70,6 +70,13 @@ async def asingle(aiter: AsyncIterable[T]) -> T:
     return item
 
 
+async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
+    """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
+    async for item in aiter:
+        return item
+    return default
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable

+ 3 - 2
tests/test_p2p_servicer.py

@@ -5,6 +5,7 @@ import pytest
 
 from hivemind.p2p import P2P, P2PContext, ServicerBase
 from hivemind.proto import test_pb2
+from hivemind.utils.asyncio import anext
 
 
 @pytest.fixture
@@ -139,9 +140,9 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
         writer.close()
     elif cancel_reason == "close_generator":
         stub = ExampleServicer.get_stub(client, server.peer_id)
-        iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
+        iter = stub.rpc_wait(test_pb2.TestRequest(number=10))
 
-        assert await iter.__anext__() == test_pb2.TestResponse(number=11)
+        assert await anext(iter) == test_pb2.TestResponse(number=11)
         await asyncio.sleep(0.25)
 
         await iter.aclose()

+ 12 - 1
tests/test_util_modules.py

@@ -13,7 +13,7 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
+from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext, asingle, azip
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
@@ -498,3 +498,14 @@ async def test_asyncio_utils():
         await anext(iterator)
 
     assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
+
+    assert await asingle(aiter(1)) == 1
+    with pytest.raises(ValueError):
+        await asingle(aiter())
+    with pytest.raises(ValueError):
+        await asingle(aiter(1, 2, 3))
+
+    assert await afirst(aiter(1)) == 1
+    assert await afirst(aiter()) is None
+    assert await afirst(aiter(), -1) == -1
+    assert await afirst(aiter(1, 2, 3)) == 1