test_util_modules.py

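"""Tests for hivemind utility modules: MPFuture, ChannelCache, MSGPack serialization,
tensor streaming helpers, asyncio utilities, BatchTensorDescriptor, and PerformanceEMA."""
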
import asyncio
import concurrent.futures
import multiprocessing as mp
import random
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pytest
import torch

import hivemind
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.dht_pb2_grpc import DHTStub
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
from hivemind.utils.asyncio import (
    achain,
    aenumerate,
    afirst,
    aiter_with_timeout,
    amap_in_executor,
    anext,
    as_aiter,
    asingle,
    attach_event_on_finished,
    azip,
    cancel_and_wait,
    enter_asynchronously,
)
from hivemind.utils.mpfuture import InvalidStateError
from hivemind.utils.performance_ema import PerformanceEMA


@pytest.mark.forked
def test_mpfuture_result():
    future = hivemind.MPFuture()

    def _proc(future):
        with pytest.raises(RuntimeError):
            future.result()  # only creator process can await result
        future.set_result(321)

    p = mp.Process(target=_proc, args=(future,))
    p.start()
    p.join()

    assert future.result() == 321
    assert future.exception() is None
    assert future.cancel() is False
    assert future.done() and not future.running() and not future.cancelled()

    future = hivemind.MPFuture()
    with pytest.raises(concurrent.futures.TimeoutError):
        future.result(timeout=1e-3)

    future.set_result(["abacaba", 123])
    assert future.result() == ["abacaba", 123]


@pytest.mark.forked
def test_mpfuture_exception():
    future = hivemind.MPFuture()
    with pytest.raises(concurrent.futures.TimeoutError):
        future.exception(timeout=1e-3)

    def _proc(future):
        future.set_exception(NotImplementedError())

    p = mp.Process(target=_proc, args=(future,))
    p.start()
    p.join()

    assert isinstance(future.exception(), NotImplementedError)
    with pytest.raises(NotImplementedError):
        future.result()
    assert future.cancel() is False
    assert future.done() and not future.running() and not future.cancelled()


@pytest.mark.forked
def test_mpfuture_cancel():
    future = hivemind.MPFuture()
    assert not future.cancelled()
    future.cancel()
    evt = mp.Event()

    def _proc():
        with pytest.raises(concurrent.futures.CancelledError):
            future.result()
        with pytest.raises(concurrent.futures.CancelledError):
            future.exception()
        with pytest.raises(InvalidStateError):
            future.set_result(123)
        with pytest.raises(InvalidStateError):
            future.set_exception(NotImplementedError())
        assert future.cancelled() and future.done() and not future.running()
        evt.set()

    p = mp.Process(target=_proc)
    p.start()
    p.join()
    assert evt.is_set()


@pytest.mark.forked
def test_mpfuture_status():
    evt = mp.Event()
    future = hivemind.MPFuture()

    def _proc1(future):
        assert future.set_running_or_notify_cancel() is True
        evt.set()

    p = mp.Process(target=_proc1, args=(future,))
    p.start()
    p.join()
    assert evt.is_set()
    evt.clear()

    assert future.running() and not future.done() and not future.cancelled()
    with pytest.raises(InvalidStateError):
        future.set_running_or_notify_cancel()

    future = hivemind.MPFuture()
    assert future.cancel()

    def _proc2(future):
        assert not future.running() and future.done() and future.cancelled()
        assert future.set_running_or_notify_cancel() is False
        evt.set()

    p = mp.Process(target=_proc2, args=(future,))
    p.start()
    p.join()
    assert evt.is_set()  # verify that _proc2's assertions actually ran

    future2 = hivemind.MPFuture()
    future2.cancel()
    assert future2.set_running_or_notify_cancel() is False


@pytest.mark.asyncio
async def test_await_mpfuture():
    # await result from the same process, but a different coroutine
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    async def wait_and_assign_async():
        assert f2.set_running_or_notify_cancel() is True
        await asyncio.sleep(0.1)
        f1.set_result((123, "ololo"))
        f2.set_result((456, "pyshpysh"))

    asyncio.create_task(wait_and_assign_async())
    assert (await asyncio.gather(f1, f2)) == [(123, "ololo"), (456, "pyshpysh")]

    # await results from separate processes
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    def wait_and_assign(future, value):
        time.sleep(0.1 * random.random())
        future.set_result(value)

    p1 = mp.Process(target=wait_and_assign, args=(f1, "abc"))
    p2 = mp.Process(target=wait_and_assign, args=(f2, "def"))
    for p in p1, p2:
        p.start()

    assert (await asyncio.gather(f1, f2)) == ["abc", "def"]
    for p in p1, p2:
        p.join()

    # await cancel
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    def wait_and_cancel():
        time.sleep(0.01)
        f2.set_result(123456)
        time.sleep(0.1)
        f1.cancel()

    p = mp.Process(target=wait_and_cancel)
    p.start()
    with pytest.raises(asyncio.CancelledError):
        # note: it is intended that an awaited MPFuture raises CancelledError when cancelled
        await asyncio.gather(f1, f2)
    p.join()

    # await exception
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    def wait_and_raise():
        time.sleep(0.01)
        f2.set_result(123456)
        time.sleep(0.1)
        f1.set_exception(ValueError("we messed up"))

    p = mp.Process(target=wait_and_raise)
    p.start()
    with pytest.raises(ValueError):
        # note: an exception set in another process is re-raised when awaiting the future
        await asyncio.gather(f1, f2)
    p.join()


@pytest.mark.forked
def test_mpfuture_bidirectional():
    evt = mp.Event()
    future_from_main = hivemind.MPFuture()

    def _future_creator():
        future_from_fork = hivemind.MPFuture()
        future_from_main.set_result(("abc", future_from_fork))
        if future_from_fork.result() == ["we", "need", "to", "go", "deeper"]:
            evt.set()

    p = mp.Process(target=_future_creator)
    p.start()

    out = future_from_main.result()
    assert isinstance(out[1], hivemind.MPFuture)
    out[1].set_result(["we", "need", "to", "go", "deeper"])

    p.join()
    assert evt.is_set()


@pytest.mark.forked
def test_mpfuture_done_callback():
    receiver, sender = mp.Pipe(duplex=False)
    events = [mp.Event() for _ in range(6)]

    def _future_creator():
        future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()

        def _check_result_and_set(future):
            assert future.done()
            assert future.result() == 123
            events[0].set()

        future1.add_done_callback(_check_result_and_set)
        future1.add_done_callback(lambda future: events[1].set())
        future2.add_done_callback(lambda future: events[2].set())
        future3.add_done_callback(lambda future: events[3].set())

        sender.send((future1, future2))
        future2.cancel()  # trigger future2 callback from the same process
        events[0].wait()
        future1.add_done_callback(
            lambda future: events[4].set()
        )  # schedule callback after future1 is already finished
        events[5].wait()

    p = mp.Process(target=_future_creator)
    p.start()
    future1, future2 = receiver.recv()
    future1.set_result(123)

    with pytest.raises(RuntimeError):
        future1.add_done_callback(lambda future: (1, 2, 3))

    assert future1.done() and not future1.cancelled()
    assert future2.done() and future2.cancelled()
    for i in 0, 1, 4:
        events[i].wait(1)
    assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
    assert not events[3].is_set()

    events[5].set()
    p.join()


@pytest.mark.forked
def test_many_futures():
    evt = mp.Event()
    receiver, sender = mp.Pipe()
    main_futures = [hivemind.MPFuture() for _ in range(1000)]
    assert len(hivemind.MPFuture._active_futures) == 1000

    def _run_peer():
        fork_futures = [hivemind.MPFuture() for _ in range(500)]
        assert len(hivemind.MPFuture._active_futures) == 500

        for i, future in enumerate(random.sample(main_futures, 300)):
            if random.random() < 0.5:
                future.set_result(i)
            else:
                future.set_exception(ValueError(f"{i}"))

        sender.send(fork_futures[:-100])
        for future in fork_futures[-100:]:
            future.cancel()

        evt.wait()
        assert len(hivemind.MPFuture._active_futures) == 200
        for future in fork_futures:
            if not future.done():
                future.set_result(123)
        assert len(hivemind.MPFuture._active_futures) == 0

    p = mp.Process(target=_run_peer)
    p.start()

    some_fork_futures = receiver.recv()

    time.sleep(0.1)  # giving enough time for the futures to be destroyed
    assert len(hivemind.MPFuture._active_futures) == 700

    for future in some_fork_futures:
        future.set_running_or_notify_cancel()
    for future in random.sample(some_fork_futures, 200):
        future.set_result(321)

    evt.set()
    for future in main_futures:
        future.cancel()
    time.sleep(0.1)  # giving enough time for the futures to be destroyed
    assert len(hivemind.MPFuture._active_futures) == 0
    p.join()
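

# Illustrative addition (not part of the original suite): the minimal cross-process
# round trip that the MPFuture tests above build upon, using only the API they
# already exercise (set_result in a child process, result in the creator process).
# The helper and test names below are ours.
def _assign_answer(future):
    future.set_result(42)


@pytest.mark.forked
def test_mpfuture_minimal_roundtrip_example():
    future = hivemind.MPFuture()
    p = mp.Process(target=_assign_answer, args=(future,))
    p.start()
    p.join()
    assert future.result() == 42  # the creator process collects the child's result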


@pytest.mark.forked
@pytest.mark.asyncio
async def test_channel_cache():
    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1

    c1 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
    c2 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
    c3 = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
    c3_again = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
    c1_again = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
    c4 = hivemind.ChannelCache.get_stub("localhost:1339", DHTStub, aio=True)
    c2_anew = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
    c1_yetagain = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)

    await asyncio.sleep(0.2)
    c1_anew = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
    c1_anew_again = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
    c1_otherstub = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub)
    await asyncio.sleep(0.05)
    c1_otherstub_again = hivemind.ChannelCache.get_stub(
        target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub
    )
    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]

    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
    assert isinstance(all_channels[-1], ConnectionHandlerStub)
    assert "aio" in repr(c2.rpc_find)
    assert "aio" not in repr(c1.rpc_find)

    duplicates = {
        (c1, c1_again),
        (c1, c1_yetagain),
        (c1_again, c1_yetagain),
        (c3, c3_again),
        (c1_anew, c1_anew_again),
        (c1_otherstub, c1_otherstub_again),
    }
    for i in range(len(all_channels)):
        for j in range(i + 1, len(all_channels)):
            ci, cj = all_channels[i], all_channels[j]
            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
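

# What the pairwise check above establishes (inferred from these assertions rather
# than from ChannelCache documentation): stubs appear to be cached per
# (target, stub_type, aio) key, repeated lookups within EVICTION_PERIOD_SECONDS
# return the very same object, and once the eviction period elapses (the 0.2 s
# sleep) the same key yields a freshly created stub.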


def test_serialize_tuple():
    test_pairs = (
        ((1, 2, 3), [1, 2, 3]),
        (("1", False, 0), ["1", False, 0]),
        (("1", False, 0), ("1", 0, 0)),
        (("1", b"qq", (2, 5, "0")), ["1", b"qq", (2, 5, "0")]),
    )

    for first, second in test_pairs:
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(first)) == first
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(second)) == second
        assert MSGPackSerializer.dumps(first) != MSGPackSerializer.dumps(second)


def test_split_parts():
    tensor = torch.randn(910, 512)
    serialized_tensor_part = serialize_torch_tensor(tensor, allow_inplace=False)
    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))

    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))

    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10**9))
    assert len(chunks3) == 1

    compressed_tensor_part = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)
    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))

    combined1 = hivemind.utils.combine_from_streaming(chunks1)
    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
    combined3 = hivemind.utils.combine_from_streaming(chunks3)
    combined4 = hivemind.utils.combine_from_streaming(chunks4)

    for combined in combined1, combined2, combined3:
        assert torch.allclose(tensor, deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)

    assert torch.allclose(tensor, deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)

    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
        with pytest.raises(RuntimeError):
            deserialize_torch_tensor(combined)
            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllreduceRunner
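

# Worked instance of the chunk-count arithmetic asserted above: the float32 tensor
# holds 910 * 512 = 465_920 elements at 4 bytes each, i.e. 1_863_680 bytes, so
# ceil(1_863_680 / 16384) = 114 chunks; with FLOAT16 compression the payload halves
# to 931_840 bytes, giving ceil(931_840 / 16384) = 57 chunks.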


def test_generic_data_classes():
    value_with_exp = ValueWithExpiration(value="string_value", expiration_time=DHTExpiration(10))
    assert value_with_exp.value == "string_value" and value_with_exp.expiration_time == DHTExpiration(10)

    heap_entry = HeapEntry(expiration_time=DHTExpiration(10), key="string_value")
    assert heap_entry.key == "string_value" and heap_entry.expiration_time == DHTExpiration(10)

    sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
    sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
    assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])


@pytest.mark.asyncio
async def test_asyncio_utils():
    res = [i async for i, item in aenumerate(as_aiter("a", "b", "c"))]
    assert res == list(range(len(res)))

    num_steps = 0
    async for elem in amap_in_executor(lambda x: x**2, as_aiter(*range(100)), max_prefetch=5):
        assert elem == num_steps**2
        num_steps += 1
    assert num_steps == 100

    ours = [
        elem
        async for elem in amap_in_executor(max, as_aiter(*range(7)), as_aiter(*range(-50, 50, 10)), max_prefetch=1)
    ]
    ref = list(map(max, range(7), range(-50, 50, 10)))
    assert ours == ref

    ours = [row async for row in azip(as_aiter("a", "b", "c"), as_aiter(1, 2, 3))]
    ref = list(zip(["a", "b", "c"], [1, 2, 3]))
    assert ours == ref

    async def _aiterate():
        yield "foo"
        yield "bar"
        yield "baz"

    iterator = _aiterate()
    assert (await anext(iterator)) == "foo"
    tail = [item async for item in iterator]
    assert tail == ["bar", "baz"]
    with pytest.raises(StopAsyncIteration):
        await anext(iterator)

    assert [item async for item in achain(_aiterate(), as_aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))

    assert await asingle(as_aiter(1)) == 1
    with pytest.raises(ValueError):
        await asingle(as_aiter())
    with pytest.raises(ValueError):
        await asingle(as_aiter(1, 2, 3))

    assert await afirst(as_aiter(1)) == 1
    assert await afirst(as_aiter()) is None
    assert await afirst(as_aiter(), -1) == -1
    assert await afirst(as_aiter(1, 2, 3)) == 1

    async def iterate_with_delays(delays):
        for i, delay in enumerate(delays):
            await asyncio.sleep(delay)
            yield i

    async for _ in aiter_with_timeout(iterate_with_delays([0.1] * 5), timeout=0.2):
        pass

    sleepy_aiter = iterate_with_delays([0.1, 0.1, 0.3, 0.1, 0.1])
    num_steps = 0
    with pytest.raises(asyncio.TimeoutError):
        async for _ in aiter_with_timeout(sleepy_aiter, timeout=0.2):
            num_steps += 1
    assert num_steps == 2  # the third item takes 0.3 s, exceeding the 0.2 s timeout

    event = asyncio.Event()
    async for i in attach_event_on_finished(iterate_with_delays([0, 0, 0, 0, 0]), event):
        assert not event.is_set()
    assert event.is_set()

    event = asyncio.Event()
    sleepy_aiter = iterate_with_delays([0.1, 0.1, 0.3, 0.1, 0.1])
    with pytest.raises(asyncio.TimeoutError):
        async for _ in attach_event_on_finished(aiter_with_timeout(sleepy_aiter, timeout=0.2), event):
            assert not event.is_set()
    assert event.is_set()


@pytest.mark.asyncio
async def test_cancel_and_wait():
    finished_gracefully = False

    async def coro_with_finalizer():
        nonlocal finished_gracefully

        try:
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            await asyncio.sleep(0.05)
            finished_gracefully = True
            raise

    task = asyncio.create_task(coro_with_finalizer())
    await asyncio.sleep(0.05)
    assert await cancel_and_wait(task)
    assert finished_gracefully

    async def coro_with_result():
        return 777

    async def coro_with_error():
        raise ValueError("error")

    task_with_result = asyncio.create_task(coro_with_result())
    task_with_error = asyncio.create_task(coro_with_error())
    await asyncio.sleep(0.05)
    assert not await cancel_and_wait(task_with_result)
    assert not await cancel_and_wait(task_with_error)


@pytest.mark.asyncio
async def test_async_context():
    lock = mp.Lock()

    async def coro1():
        async with enter_asynchronously(lock):
            await asyncio.sleep(0.2)

    async def coro2():
        await asyncio.sleep(0.1)
        async with enter_asynchronously(lock):
            await asyncio.sleep(0.1)

    await asyncio.wait_for(asyncio.gather(coro1(), coro2()), timeout=0.5)
    # running this without enter_asynchronously would deadlock the event loop
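

# Why enter_asynchronously is needed above: a plain `with lock:` in coro2 would block
# the event loop thread while waiting for the lock, so coro1 (which holds the lock and
# is sleeping on that same loop) could never resume to release it. A deadlocking
# counter-example, sketched for contrast:
#
#     async def coro2_blocking():
#         await asyncio.sleep(0.1)
#         with lock:  # blocks the event loop thread until the lock is free
#             await asyncio.sleep(0.1)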


def test_batch_tensor_descriptor_msgpack():
    tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
    tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))

    assert (
        tensor_descr.size == tensor_descr_roundtrip.size
        and tensor_descr.dtype == tensor_descr_roundtrip.dtype
        and tensor_descr.layout == tensor_descr_roundtrip.layout
        and tensor_descr.device == tensor_descr_roundtrip.device
        and tensor_descr.requires_grad == tensor_descr_roundtrip.requires_grad
        and tensor_descr.pin_memory == tensor_descr_roundtrip.pin_memory
        and tensor_descr.compression == tensor_descr_roundtrip.compression
    )


@pytest.mark.parametrize("max_workers", [1, 2, 10])
def test_performance_ema_threadsafe(
    max_workers: int,
    interval: float = 0.01,
    num_updates: int = 100,
    alpha: float = 0.05,
    bias_power: float = 0.7,
    tolerance: float = 0.05,
):
    def run_task(ema):
        task_size = random.randint(1, 4)
        with ema.update_threadsafe(task_size):
            time.sleep(task_size * interval * (0.9 + 0.2 * random.random()))
        return task_size

    with ThreadPoolExecutor(max_workers) as pool:
        ema = PerformanceEMA(alpha=alpha)
        start_time = time.perf_counter()
        futures = [pool.submit(run_task, ema) for _ in range(num_updates)]
        total_size = sum(future.result() for future in futures)
        end_time = time.perf_counter()
        target = total_size / (end_time - start_time)
        assert ema.samples_per_second >= (1 - tolerance) * target * max_workers ** (bias_power - 1)
        assert ema.samples_per_second <= (1 + tolerance) * target
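        # Worked instance of the bounds above: with max_workers=10, bias_power=0.7, and
        # tolerance=0.05, the lower bound is (1 - 0.05) * target * 10 ** (0.7 - 1)
        # ≈ 0.95 * 0.501 * target ≈ 0.48 * target, i.e. the EMA may report roughly half
        # the aggregate throughput when ten workers overlap, yet must never exceed
        # (1 + 0.05) * target.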