|
@@ -13,7 +13,7 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
|
|
|
from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
|
|
|
from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
|
|
|
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
|
|
|
+from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext, asingle, azip
|
|
|
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
|
|
|
from hivemind.utils.mpfuture import InvalidStateError
|
|
|
|
|
@@ -498,3 +498,14 @@ async def test_asyncio_utils():
|
|
|
await anext(iterator)
|
|
|
|
|
|
assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
|
|
|
+
|
|
|
+ assert await asingle(aiter(1)) == 1
|
|
|
+ with pytest.raises(ValueError):
|
|
|
+ await asingle(aiter())
|
|
|
+ with pytest.raises(ValueError):
|
|
|
+ await asingle(aiter(1, 2, 3))
|
|
|
+
|
|
|
+ assert await afirst(aiter(1)) == 1
|
|
|
+ assert await afirst(aiter()) is None
|
|
|
+ assert await afirst(aiter(), -1) == -1
|
|
|
+ assert await afirst(aiter(1, 2, 3)) == 1
|