Alexander Borzunov 4 jaren geleden
bovenliggende
commit
b444676e51

+ 3 - 3
hivemind/averaging/allreduce.py

@@ -8,7 +8,7 @@ from hivemind.averaging.partition import AllreduceException, TensorPartContainer
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext
+from hivemind.utils.asyncio import achain, aenumerate, afirst, as_aiter, amap_in_executor, anext
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 # flavour types
@@ -193,7 +193,7 @@ class AllReduceRunner(ServicerBase):
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
                 sender_index = self.sender_peer_ids.index(context.remote_id)
-                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
+                async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
                     yield msg
 
             except Exception as e:
@@ -232,7 +232,7 @@ class AllReduceRunner(ServicerBase):
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
         # Coroutines are lazy, so we take the first item to start the couroutine's execution
-        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
+        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 2 - 2
hivemind/averaging/averager.py

@@ -25,7 +25,7 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2, runtime_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter, aiter_with_timeout, anext, switch_to_uvloop
+from hivemind.utils.asyncio import achain, as_aiter, aiter_with_timeout, anext, switch_to_uvloop
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
@@ -496,7 +496,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return
 
-        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
+        async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message
 
     async def _declare_for_download_periodically(self):

+ 2 - 2
hivemind/p2p/p2p_daemon.py

@@ -15,7 +15,7 @@ import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
-from hivemind.utils.asyncio import aiter, asingle
+from hivemind.utils.asyncio import as_aiter, asingle
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -480,7 +480,7 @@ class P2P:
         input: Union[TInputProtobuf, TInputStream],
         output_protobuf_type: Type[Message],
     ) -> TOutputStream:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
+        requests = input if isinstance(input, AsyncIterableABC) else as_aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
 
     def _start_listening(self) -> None:

+ 1 - 1
hivemind/utils/asyncio.py

@@ -28,7 +28,7 @@ async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
     return await aiter.__anext__()
 
 
-async def aiter(*args: T) -> AsyncIterator[T]:
+async def as_aiter(*args: T) -> AsyncIterator[T]:
     """create an asynchronous iterator from a sequence of values"""
     for arg in args:
         yield arg

+ 13 - 13
tests/test_util_modules.py

@@ -17,7 +17,7 @@ from hivemind.utils.asyncio import (
     achain,
     aenumerate,
     afirst,
-    aiter,
+    as_aiter,
     amap_in_executor,
     anext,
     asingle,
@@ -478,20 +478,20 @@ def test_generic_data_classes():
 
 @pytest.mark.asyncio
 async def test_asyncio_utils():
-    res = [i async for i, item in aenumerate(aiter("a", "b", "c"))]
+    res = [i async for i, item in aenumerate(as_aiter("a", "b", "c"))]
     assert res == list(range(len(res)))
 
     num_steps = 0
-    async for elem in amap_in_executor(lambda x: x ** 2, aiter(*range(100)), max_prefetch=5):
+    async for elem in amap_in_executor(lambda x: x ** 2, as_aiter(*range(100)), max_prefetch=5):
         assert elem == num_steps ** 2
         num_steps += 1
     assert num_steps == 100
 
-    ours = [elem async for elem in amap_in_executor(max, aiter(*range(7)), aiter(*range(-50, 50, 10)), max_prefetch=1)]
+    ours = [elem async for elem in amap_in_executor(max, as_aiter(*range(7)), as_aiter(*range(-50, 50, 10)), max_prefetch=1)]
     ref = list(map(max, range(7), range(-50, 50, 10)))
     assert ours == ref
 
-    ours = [row async for row in azip(aiter("a", "b", "c"), aiter(1, 2, 3))]
+    ours = [row async for row in azip(as_aiter("a", "b", "c"), as_aiter(1, 2, 3))]
     ref = list(zip(["a", "b", "c"], [1, 2, 3]))
     assert ours == ref
 
@@ -507,18 +507,18 @@ async def test_asyncio_utils():
     with pytest.raises(StopAsyncIteration):
         await anext(iterator)
 
-    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
+    assert [item async for item in achain(_aiterate(), as_aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
 
-    assert await asingle(aiter(1)) == 1
+    assert await asingle(as_aiter(1)) == 1
     with pytest.raises(ValueError):
-        await asingle(aiter())
+        await asingle(as_aiter())
     with pytest.raises(ValueError):
-        await asingle(aiter(1, 2, 3))
+        await asingle(as_aiter(1, 2, 3))
 
-    assert await afirst(aiter(1)) == 1
-    assert await afirst(aiter()) is None
-    assert await afirst(aiter(), -1) == -1
-    assert await afirst(aiter(1, 2, 3)) == 1
+    assert await afirst(as_aiter(1)) == 1
+    assert await afirst(as_aiter()) is None
+    assert await afirst(as_aiter(), -1) == -1
+    assert await afirst(as_aiter(1, 2, 3)) == 1
 
 
 @pytest.mark.asyncio