import asyncio
import concurrent.futures
import multiprocessing as mp
import random
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pytest
import torch

import hivemind
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
from hivemind.utils.asyncio import (
    achain,
    aenumerate,
    aiter_with_timeout,
    amap_in_executor,
    anext,
    as_aiter,
    asingle,
    attach_event_on_finished,
    azip,
    cancel_and_wait,
    enter_asynchronously,
)
from hivemind.utils.mpfuture import InvalidStateError
from hivemind.utils.performance_ema import PerformanceEMA


@pytest.mark.forked
def test_mpfuture_result():
    future = hivemind.MPFuture()

    def _proc(future):
        with pytest.raises(RuntimeError):
            future.result()  # only creator process can await result

        future.set_result(321)

    p = mp.Process(target=_proc, args=(future,))
    p.start()
    p.join()

    assert future.result() == 321
    assert future.exception() is None
    assert future.cancel() is False
    assert future.done() and not future.running() and not future.cancelled()

    future = hivemind.MPFuture()
    with pytest.raises(concurrent.futures.TimeoutError):
        future.result(timeout=1e-3)

    future.set_result(["abacaba", 123])
    assert future.result() == ["abacaba", 123]


@pytest.mark.forked
def test_mpfuture_exception():
    future = hivemind.MPFuture()
    with pytest.raises(concurrent.futures.TimeoutError):
        future.exception(timeout=1e-3)

    def _proc(future):
        future.set_exception(NotImplementedError())

    p = mp.Process(target=_proc, args=(future,))
    p.start()
    p.join()

    assert isinstance(future.exception(), NotImplementedError)
    with pytest.raises(NotImplementedError):
        future.result()
    assert future.cancel() is False
    assert future.done() and not future.running() and not future.cancelled()


@pytest.mark.forked
def test_mpfuture_cancel():
    future = hivemind.MPFuture()
    assert not future.cancelled()
    future.cancel()
    evt = mp.Event()

    def _proc():
        with pytest.raises(concurrent.futures.CancelledError):
            future.result()
        with pytest.raises(concurrent.futures.CancelledError):
            future.exception()
        with pytest.raises(InvalidStateError):
            future.set_result(123)
        with pytest.raises(InvalidStateError):
            future.set_exception(NotImplementedError())
        assert future.cancelled() and future.done() and not future.running()
        evt.set()

    p = mp.Process(target=_proc)
    p.start()
    p.join()
    assert evt.is_set()


@pytest.mark.forked
def test_mpfuture_status():
    evt = mp.Event()
    future = hivemind.MPFuture()

    def _proc1(future):
        assert future.set_running_or_notify_cancel() is True
        evt.set()

    p = mp.Process(target=_proc1, args=(future,))
    p.start()
    p.join()
    assert evt.is_set()
    evt.clear()

    assert future.running() and not future.done() and not future.cancelled()
    with pytest.raises(InvalidStateError):
        future.set_running_or_notify_cancel()

    future = hivemind.MPFuture()
    assert future.cancel()

    def _proc2(future):
        assert not future.running() and future.done() and future.cancelled()
        assert future.set_running_or_notify_cancel() is False
        evt.set()

    p = mp.Process(target=_proc2, args=(future,))
    p.start()
    p.join()
    assert evt.is_set()

    future2 = hivemind.MPFuture()
    future2.cancel()
    assert future2.set_running_or_notify_cancel() is False
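

# Illustrative sketch, not one of the original tests: the basic MPFuture pattern
# that the tests above exercise. The future is created in the parent process,
# handed to a child, and fulfilled there; only the creator process may await it.
# The helper name `_example_mpfuture_roundtrip` is ours, for illustration only.
def _example_mpfuture_roundtrip():
    future = hivemind.MPFuture()

    def _worker(fut):
        fut.set_result("computed in a child process")  # deliver the result across processes

    p = mp.Process(target=_worker, args=(future,))
    p.start()
    result = future.result(timeout=5)  # blocks the creator until the child sets a result
    p.join()
    return result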


@pytest.mark.asyncio
async def test_await_mpfuture():
    # await result from the same process, but a different coroutine
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    async def wait_and_assign_async():
        assert f2.set_running_or_notify_cancel() is True
        await asyncio.sleep(0.1)
        f1.set_result((123, "ololo"))
        f2.set_result((456, "pyshpysh"))

    asyncio.create_task(wait_and_assign_async())

    assert (await asyncio.gather(f1, f2)) == [(123, "ololo"), (456, "pyshpysh")]

    # await result from separate processes
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    def wait_and_assign(future, value):
        time.sleep(0.1 * random.random())
        future.set_result(value)

    p1 = mp.Process(target=wait_and_assign, args=(f1, "abc"))
    p2 = mp.Process(target=wait_and_assign, args=(f2, "def"))
    for p in p1, p2:
        p.start()

    assert (await asyncio.gather(f1, f2)) == ["abc", "def"]
    for p in p1, p2:
        p.join()

    # await cancel
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    def wait_and_cancel():
        time.sleep(0.01)
        f2.set_result(123456)
        time.sleep(0.1)
        f1.cancel()

    p = mp.Process(target=wait_and_cancel)
    p.start()

    with pytest.raises(asyncio.CancelledError):
        # note: it is intended that awaiting a cancelled MPFuture raises CancelledError
        await asyncio.gather(f1, f2)

    p.join()

    # await exception
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    def wait_and_raise():
        time.sleep(0.01)
        f2.set_result(123456)
        time.sleep(0.1)
        f1.set_exception(ValueError("we messed up"))

    p = mp.Process(target=wait_and_raise)
    p.start()

    with pytest.raises(ValueError):
        # note: the exception set on f1 propagates through asyncio.gather
        await asyncio.gather(f1, f2)

    p.join()


@pytest.mark.forked
def test_mpfuture_bidirectional():
    evt = mp.Event()
    future_from_main = hivemind.MPFuture()

    def _future_creator():
        future_from_fork = hivemind.MPFuture()
        future_from_main.set_result(("abc", future_from_fork))

        if future_from_fork.result() == ["we", "need", "to", "go", "deeper"]:
            evt.set()

    p = mp.Process(target=_future_creator)
    p.start()

    out = future_from_main.result()
    assert isinstance(out[1], hivemind.MPFuture)
    out[1].set_result(["we", "need", "to", "go", "deeper"])

    p.join()
    assert evt.is_set()


@pytest.mark.forked
def test_mpfuture_done_callback():
    receiver, sender = mp.Pipe(duplex=False)
    events = [mp.Event() for _ in range(6)]

    def _future_creator():
        future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()

        def _check_result_and_set(future):
            assert future.done()
            assert future.result() == 123
            events[0].set()

        future1.add_done_callback(_check_result_and_set)
        future1.add_done_callback(lambda future: events[1].set())
        future2.add_done_callback(lambda future: events[2].set())
        future3.add_done_callback(lambda future: events[3].set())

        sender.send((future1, future2))
        future2.cancel()  # trigger future2 callback from the same process

        events[0].wait()
        future1.add_done_callback(
            lambda future: events[4].set()
        )  # schedule callback after future1 is already finished

        events[5].wait()

    p = mp.Process(target=_future_creator)
    p.start()

    future1, future2 = receiver.recv()
    future1.set_result(123)

    with pytest.raises(RuntimeError):
        future1.add_done_callback(lambda future: (1, 2, 3))

    assert future1.done() and not future1.cancelled()
    assert future2.done() and future2.cancelled()
    for i in 0, 1, 4:
        events[i].wait(1)
    assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
    assert not events[3].is_set()

    events[5].set()
    p.join()
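

# Illustrative sketch, not one of the original tests: done-callbacks fire in the
# process that added them, and (as test_mpfuture_done_callback shows) adding a
# callback from a non-creator process raises RuntimeError. The helper name
# `_example_done_callback` is ours, for illustration only.
def _example_done_callback():
    future = hivemind.MPFuture()
    future.add_done_callback(lambda fut: print("done:", fut.result()))
    future.set_result(42)  # completing the future triggers the callback in this process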


@pytest.mark.forked
def test_many_futures():
    evt = mp.Event()
    receiver, sender = mp.Pipe()
    main_futures = [hivemind.MPFuture() for _ in range(1000)]
    assert len(hivemind.MPFuture._active_futures) == 1000

    def _run_peer():
        fork_futures = [hivemind.MPFuture() for _ in range(500)]
        assert len(hivemind.MPFuture._active_futures) == 500

        for i, future in enumerate(random.sample(main_futures, 300)):
            if random.random() < 0.5:
                future.set_result(i)
            else:
                future.set_exception(ValueError(f"{i}"))

        sender.send(fork_futures[:-100])
        for future in fork_futures[-100:]:
            future.cancel()

        evt.wait()

        assert len(hivemind.MPFuture._active_futures) == 200
        for future in fork_futures:
            if not future.done():
                future.set_result(123)
        assert len(hivemind.MPFuture._active_futures) == 0

    p = mp.Process(target=_run_peer)
    p.start()

    some_fork_futures = receiver.recv()

    time.sleep(0.1)  # giving enough time for the futures to be destroyed
    assert len(hivemind.MPFuture._active_futures) == 700

    for future in some_fork_futures:
        future.set_running_or_notify_cancel()
    for future in random.sample(some_fork_futures, 200):
        future.set_result(321)

    evt.set()

    for future in main_futures:
        future.cancel()
    time.sleep(0.1)  # giving enough time for the futures to be destroyed
    assert len(hivemind.MPFuture._active_futures) == 0
    p.join()


def test_serialize_tuple():
    test_pairs = (
        ((1, 2, 3), [1, 2, 3]),
        (("1", False, 0), ["1", False, 0]),
        (("1", False, 0), ("1", 0, 0)),
        (("1", b"qq", (2, 5, "0")), ["1", b"qq", (2, 5, "0")]),
    )

    for first, second in test_pairs:
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(first)) == first
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(second)) == second
        assert MSGPackSerializer.dumps(first) != MSGPackSerializer.dumps(second)


def test_split_parts():
    tensor = torch.randn(910, 512)
    serialized_tensor_part = serialize_torch_tensor(tensor, allow_inplace=False)
    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))

    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))

    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10**9))
    assert len(chunks3) == 1

    compressed_tensor_part = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)
    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))

    combined1 = hivemind.utils.combine_from_streaming(chunks1)
    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
    combined3 = hivemind.utils.combine_from_streaming(chunks3)
    combined4 = hivemind.utils.combine_from_streaming(chunks4)
    for combined in combined1, combined2, combined3:
        assert torch.allclose(tensor, deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)

    assert torch.allclose(tensor, deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)

    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
        with pytest.raises(RuntimeError):
            deserialize_torch_tensor(combined)
            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllReduceRunner
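

# Illustrative sketch, not one of the original tests: the full round trip that
# test_split_parts verifies chunk counts for -- serialize a tensor, split it into
# streamable chunks, reassemble them, and deserialize. The helper name and the
# 16 KiB default are ours, mirroring the chunk size used in the test.
def _example_stream_tensor(tensor: torch.Tensor, chunk_size_bytes: int = 16384) -> torch.Tensor:
    serialized = serialize_torch_tensor(tensor, allow_inplace=False)
    chunks = hivemind.utils.split_for_streaming(serialized, chunk_size_bytes)
    # in real use, each chunk would be sent over the network here
    combined = hivemind.utils.combine_from_streaming(chunks)
    return deserialize_torch_tensor(combined)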


def test_generic_data_classes():
    value_with_exp = ValueWithExpiration(value="string_value", expiration_time=DHTExpiration(10))
    assert value_with_exp.value == "string_value" and value_with_exp.expiration_time == DHTExpiration(10)

    heap_entry = HeapEntry(expiration_time=DHTExpiration(10), key="string_value")
    assert heap_entry.key == "string_value" and heap_entry.expiration_time == DHTExpiration(10)

    sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
    sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
    assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])


@pytest.mark.asyncio
async def test_asyncio_utils():
    res = [i async for i, item in aenumerate(as_aiter("a", "b", "c"))]
    assert res == list(range(len(res)))

    num_steps = 0
    async for elem in amap_in_executor(lambda x: x**2, as_aiter(*range(100)), max_prefetch=5):
        assert elem == num_steps**2
        num_steps += 1
    assert num_steps == 100

    ours = [
        elem async for elem in amap_in_executor(max, as_aiter(*range(7)), as_aiter(*range(-50, 50, 10)), max_prefetch=1)
    ]
    ref = list(map(max, range(7), range(-50, 50, 10)))
    assert ours == ref

    ours = [row async for row in azip(as_aiter("a", "b", "c"), as_aiter(1, 2, 3))]
    ref = list(zip(["a", "b", "c"], [1, 2, 3]))
    assert ours == ref

    async def _aiterate():
        yield "foo"
        yield "bar"
        yield "baz"

    iterator = _aiterate()
    assert (await anext(iterator)) == "foo"
    tail = [item async for item in iterator]
    assert tail == ["bar", "baz"]
    with pytest.raises(StopAsyncIteration):
        await anext(iterator)

    assert [item async for item in achain(_aiterate(), as_aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))

    assert await asingle(as_aiter(1)) == 1
    with pytest.raises(ValueError):
        await asingle(as_aiter())
    with pytest.raises(ValueError):
        await asingle(as_aiter(1, 2, 3))

    async def iterate_with_delays(delays):
        for i, delay in enumerate(delays):
            await asyncio.sleep(delay)
            yield i

    async for _ in aiter_with_timeout(iterate_with_delays([0.1] * 5), timeout=0.2):
        pass

    sleepy_aiter = iterate_with_delays([0.1, 0.1, 0.3, 0.1, 0.1])
    num_steps = 0
    with pytest.raises(asyncio.TimeoutError):
        async for _ in aiter_with_timeout(sleepy_aiter, timeout=0.2):
            num_steps += 1

    assert num_steps == 2

    event = asyncio.Event()
    async for i in attach_event_on_finished(iterate_with_delays([0, 0, 0, 0, 0]), event):
        assert not event.is_set()
    assert event.is_set()

    event = asyncio.Event()
    sleepy_aiter = iterate_with_delays([0.1, 0.1, 0.3, 0.1, 0.1])
    with pytest.raises(asyncio.TimeoutError):
        async for _ in attach_event_on_finished(aiter_with_timeout(sleepy_aiter, timeout=0.2), event):
            assert not event.is_set()
    assert event.is_set()


@pytest.mark.asyncio
async def test_cancel_and_wait():
    finished_gracefully = False

    async def coro_with_finalizer():
        nonlocal finished_gracefully

        try:
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            await asyncio.sleep(0.05)
            finished_gracefully = True
            raise

    task = asyncio.create_task(coro_with_finalizer())
    await asyncio.sleep(0.05)
    assert await cancel_and_wait(task)
    assert finished_gracefully

    async def coro_with_result():
        return 777

    async def coro_with_error():
        raise ValueError("error")

    task_with_result = asyncio.create_task(coro_with_result())
    task_with_error = asyncio.create_task(coro_with_error())
    await asyncio.sleep(0.05)
    assert not await cancel_and_wait(task_with_result)
    assert not await cancel_and_wait(task_with_error)


@pytest.mark.asyncio
async def test_async_context():
    lock = mp.Lock()

    async def coro1():
        async with enter_asynchronously(lock):
            await asyncio.sleep(0.2)

    async def coro2():
        await asyncio.sleep(0.1)
        async with enter_asynchronously(lock):
            await asyncio.sleep(0.1)

    await asyncio.wait_for(asyncio.gather(coro1(), coro2()), timeout=0.5)
    # running this without enter_asynchronously would deadlock the event loop
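

# Illustrative sketch, not one of the original tests: composing the asyncio
# helpers exercised above into a small pipeline. amap_in_executor applies a
# (possibly blocking) function in a background executor with bounded prefetch;
# aiter_with_timeout guards each step. The helper name is ours.
async def _example_async_pipeline():
    squares = amap_in_executor(lambda x: x**2, as_aiter(*range(10)), max_prefetch=2)
    async for i, square in aenumerate(aiter_with_timeout(squares, timeout=1.0)):
        assert square == i**2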


def test_batch_tensor_descriptor_msgpack():
    tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
    tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))

    assert (
        tensor_descr.size == tensor_descr_roundtrip.size
        and tensor_descr.dtype == tensor_descr_roundtrip.dtype
        and tensor_descr.layout == tensor_descr_roundtrip.layout
        and tensor_descr.device == tensor_descr_roundtrip.device
        and tensor_descr.requires_grad == tensor_descr_roundtrip.requires_grad
        and tensor_descr.pin_memory == tensor_descr_roundtrip.pin_memory
        and tensor_descr.compression == tensor_descr_roundtrip.compression
    )


@pytest.mark.parametrize("max_workers", [1, 2, 10])
def test_performance_ema_threadsafe(
    max_workers: int,
    interval: float = 0.01,
    num_updates: int = 100,
    alpha: float = 0.05,
    bias_power: float = 0.7,
    tolerance: float = 0.05,
):
    def run_task(ema):
        task_size = random.randint(1, 4)
        with ema.update_threadsafe(task_size):
            time.sleep(task_size * interval * (0.9 + 0.2 * random.random()))
        return task_size

    with ThreadPoolExecutor(max_workers) as pool:
        ema = PerformanceEMA(alpha=alpha)
        start_time = time.perf_counter()
        futures = [pool.submit(run_task, ema) for _ in range(num_updates)]
        total_size = sum(future.result() for future in futures)
        end_time = time.perf_counter()
        target = total_size / (end_time - start_time)
        assert ema.samples_per_second >= (1 - tolerance) * target * max_workers ** (bias_power - 1)
        assert ema.samples_per_second <= (1 + tolerance) * target
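

# Illustrative sketch, not one of the original tests: the PerformanceEMA usage
# pattern measured by test_performance_ema_threadsafe. update_threadsafe() times
# the enclosed block and credits it with the given number of samples. The helper
# name and the constants are ours, for illustration only.
def _example_performance_ema():
    ema = PerformanceEMA(alpha=0.1)
    for _ in range(10):
        with ema.update_threadsafe(4):  # this block "processes" 4 samples
            time.sleep(0.01)
    return ema.samples_per_second  # smoothed throughput estimate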