|
@@ -22,7 +22,7 @@ from hivemind.utils.asyncio import (
|
|
|
anext,
|
|
|
asingle,
|
|
|
azip,
|
|
|
- cancel_and_wait,
|
|
|
+ cancel_and_wait, aiter_with_timeout,
|
|
|
)
|
|
|
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
|
|
|
from hivemind.utils.mpfuture import InvalidStateError
|
|
@@ -520,6 +520,22 @@ async def test_asyncio_utils():
|
|
|
assert await afirst(as_aiter(), -1) == -1
|
|
|
assert await afirst(as_aiter(1, 2, 3)) == 1
|
|
|
|
|
|
+ async def iterate(delays):
|
|
|
+ for i, delay in enumerate(delays):
|
|
|
+ await asyncio.sleep(delay)
|
|
|
+ yield i
|
|
|
+
|
|
|
+ async for _ in aiter_with_timeout(iterate([0.1] * 5), timeout=0.2):
|
|
|
+ pass
|
|
|
+
|
|
|
+ sleepy_aiter = iterate([0.1, 0.1, 0.3, 0.1, 0.1])
|
|
|
+ num_steps = 0
|
|
|
+ with pytest.raises(asyncio.TimeoutError):
|
|
|
+ async for _ in aiter_with_timeout(sleepy_aiter, timeout=0.2):
|
|
|
+ num_steps += 1
|
|
|
+
|
|
|
+ assert num_steps == 2
|
|
|
+
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
async def test_cancel_and_wait():
|