
import asyncio
import concurrent.futures
import multiprocessing as mp
import random
import time

import numpy as np
import pytest
import torch

import hivemind
from hivemind.proto.dht_pb2_grpc import DHTStub
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
from hivemind.utils import MSGPackSerializer
from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.mpfuture import InvalidStateError
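
# The MPFuture tests below exercise hivemind.MPFuture, a multiprocessing-aware
# analogue of concurrent.futures.Future: a future created in one process can be
# resolved, cancelled, or awaited from another, with the caveat (checked in
# test_mpfuture_result) that only the creator process may await the result.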


@pytest.mark.forked
def test_mpfuture_result():
    future = hivemind.MPFuture()

    def _proc(future):
        with pytest.raises(RuntimeError):
            future.result()  # only creator process can await result
        future.set_result(321)

    p = mp.Process(target=_proc, args=(future,))
    p.start()
    p.join()

    assert future.result() == 321
    assert future.exception() is None
    assert future.cancel() is False
    assert future.done() and not future.running() and not future.cancelled()

    future = hivemind.MPFuture()
    with pytest.raises(concurrent.futures.TimeoutError):
        future.result(timeout=1e-3)

    future.set_result(['abacaba', 123])
    assert future.result() == ['abacaba', 123]
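

# A minimal sketch of the cross-process pattern verified above, assuming only
# what the test itself uses (illustrative helper; not collected by pytest):
def _example_resolve_in_child():
    future = hivemind.MPFuture()  # created by the parent, hence awaitable here
    child = mp.Process(target=future.set_result, args=(42,))  # resolve from a fork
    child.start()
    child.join()
    return future.result()  # 42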


@pytest.mark.forked
def test_mpfuture_exception():
    future = hivemind.MPFuture()
    with pytest.raises(concurrent.futures.TimeoutError):
        future.exception(timeout=1e-3)

    def _proc(future):
        future.set_exception(NotImplementedError())

    p = mp.Process(target=_proc, args=(future,))
    p.start()
    p.join()

    assert isinstance(future.exception(), NotImplementedError)
    with pytest.raises(NotImplementedError):
        future.result()
    assert future.cancel() is False
    assert future.done() and not future.running() and not future.cancelled()


@pytest.mark.forked
def test_mpfuture_cancel():
    future = hivemind.MPFuture()
    assert not future.cancelled()
    future.cancel()
    evt = mp.Event()

    def _proc():
        with pytest.raises(concurrent.futures.CancelledError):
            future.result()
        with pytest.raises(concurrent.futures.CancelledError):
            future.exception()
        with pytest.raises(InvalidStateError):
            future.set_result(123)
        with pytest.raises(InvalidStateError):
            future.set_exception(NotImplementedError())
        assert future.cancelled() and future.done() and not future.running()
        evt.set()

    p = mp.Process(target=_proc)
    p.start()
    p.join()
    assert evt.is_set()


@pytest.mark.forked
def test_mpfuture_status():
    evt = mp.Event()
    future = hivemind.MPFuture()

    def _proc1(future):
        assert future.set_running_or_notify_cancel() is True
        evt.set()

    p = mp.Process(target=_proc1, args=(future,))
    p.start()
    p.join()
    assert evt.is_set()
    evt.clear()

    assert future.running() and not future.done() and not future.cancelled()
    with pytest.raises(InvalidStateError):
        future.set_running_or_notify_cancel()

    future = hivemind.MPFuture()
    assert future.cancel()

    def _proc2(future):
        assert not future.running() and future.done() and future.cancelled()
        assert future.set_running_or_notify_cancel() is False
        evt.set()

    p = mp.Process(target=_proc2, args=(future,))
    p.start()
    p.join()
    assert evt.is_set()  # the child, not the parent, must have set this

    future2 = hivemind.MPFuture()
    future2.cancel()
    assert future2.set_running_or_notify_cancel() is False


@pytest.mark.asyncio
async def test_await_mpfuture():
    # await result from the same process, but a different coroutine
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    async def wait_and_assign_async():
        assert f2.set_running_or_notify_cancel() is True
        await asyncio.sleep(0.1)
        f1.set_result((123, 'ololo'))
        f2.set_result((456, 'pyshpysh'))

    asyncio.create_task(wait_and_assign_async())
    assert (await asyncio.gather(f1, f2)) == [(123, 'ololo'), (456, 'pyshpysh')]

    # await result from separate processes
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    def wait_and_assign(future, value):
        time.sleep(0.1 * random.random())
        future.set_result(value)

    p1 = mp.Process(target=wait_and_assign, args=(f1, 'abc'))
    p2 = mp.Process(target=wait_and_assign, args=(f2, 'def'))
    for p in p1, p2:
        p.start()

    assert (await asyncio.gather(f1, f2)) == ['abc', 'def']
    for p in p1, p2:
        p.join()

    # await cancel
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    def wait_and_cancel():
        time.sleep(0.01)
        f2.set_result(123456)
        time.sleep(0.1)
        f1.cancel()

    p = mp.Process(target=wait_and_cancel)
    p.start()
    with pytest.raises(asyncio.CancelledError):
        # note: it is intended that a cancelled MPFuture raises CancelledError when awaited
        await asyncio.gather(f1, f2)
    p.join()

    # await exception
    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()

    def wait_and_raise():
        time.sleep(0.01)
        f2.set_result(123456)
        time.sleep(0.1)
        f1.set_exception(ValueError('we messed up'))

    p = mp.Process(target=wait_and_raise)
    p.start()
    with pytest.raises(ValueError):
        # note: an exception set in another process is re-raised in the awaiting coroutine
        await asyncio.gather(f1, f2)
    p.join()
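

# Note on the test above: MPFuture appears to implement the awaitable protocol,
# which is why bare futures can be passed straight to asyncio.gather; a result,
# cancellation, or exception set in another process wakes the awaiting coroutine.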


@pytest.mark.forked
def test_mpfuture_bidirectional():
    evt = mp.Event()
    future_from_main = hivemind.MPFuture()

    def _future_creator():
        future_from_fork = hivemind.MPFuture()
        future_from_main.set_result(('abc', future_from_fork))

        if future_from_fork.result() == ['we', 'need', 'to', 'go', 'deeper']:
            evt.set()

    p = mp.Process(target=_future_creator)
    p.start()

    out = future_from_main.result()
    assert isinstance(out[1], hivemind.MPFuture)
    out[1].set_result(['we', 'need', 'to', 'go', 'deeper'])

    p.join()
    assert evt.is_set()


@pytest.mark.forked
def test_mpfuture_done_callback():
    receiver, sender = mp.Pipe(duplex=False)
    events = [mp.Event() for _ in range(5)]

    def _future_creator():
        future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()

        def _check_result_and_set(future):
            assert future.done()
            assert future.result() == 123
            events[0].set()

        future1.add_done_callback(_check_result_and_set)
        future1.add_done_callback(lambda future: events[1].set())
        future2.add_done_callback(lambda future: events[2].set())
        future3.add_done_callback(lambda future: events[3].set())

        sender.send((future1, future2))
        future2.cancel()  # trigger future2 callback from the same process
        events[0].wait()
        future1.add_done_callback(lambda future: events[4].set())  # schedule callback after future1 is already finished

    p = mp.Process(target=_future_creator)
    p.start()

    future1, future2 = receiver.recv()
    future1.set_result(123)

    with pytest.raises(RuntimeError):
        future1.add_done_callback(lambda future: (1, 2, 3))

    p.join()

    events[0].wait(1)
    events[1].wait(1)
    assert future1.done() and not future1.cancelled()
    assert future2.done() and future2.cancelled()
    assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
    assert not events[3].is_set()


@pytest.mark.forked
def test_many_futures():
    evt = mp.Event()
    receiver, sender = mp.Pipe()
    main_futures = [hivemind.MPFuture() for _ in range(1000)]
    assert len(hivemind.MPFuture._active_futures) == 1000

    def _run_peer():
        fork_futures = [hivemind.MPFuture() for _ in range(500)]
        assert len(hivemind.MPFuture._active_futures) == 500

        for i, future in enumerate(random.sample(main_futures, 300)):
            if random.random() < 0.5:
                future.set_result(i)
            else:
                future.set_exception(ValueError(f"{i}"))

        sender.send(fork_futures[:-100])
        for future in fork_futures[-100:]:
            future.cancel()

        evt.wait()
        assert len(hivemind.MPFuture._active_futures) == 200
        for future in fork_futures:
            future.cancel()
        assert len(hivemind.MPFuture._active_futures) == 0

    p = mp.Process(target=_run_peer)
    p.start()

    some_fork_futures = receiver.recv()
    assert len(hivemind.MPFuture._active_futures) == 700
    for future in some_fork_futures:
        future.set_running_or_notify_cancel()
    for future in random.sample(some_fork_futures, 200):
        future.set_result(321)

    time.sleep(0.5)
    evt.set()

    for future in main_futures:
        future.cancel()
    assert len(hivemind.MPFuture._active_futures) == 0
    p.join()
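

# MPFuture._active_futures (used above) looks like per-process bookkeeping of
# futures created and not yet finished in the current process: the parent counts
# 1000, then 700 after 300 of its futures are resolved elsewhere, while the fork
# counts only the 500 it created itself.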


def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
    assert error.square().mean() < beta

    zeros = torch.zeros(5, 5)
    for compression_type in CompressionType.values():
        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
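

# A hedged sketch of the round-trip measured above: serialize with a given
# compression type, restore, and return the mean squared error (illustrative
# helper; not collected by pytest).
def _example_compression_mse(compression=CompressionType.FLOAT16):
    tensor = torch.randn(128, 128)
    restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, compression))
    return (restored - tensor).square().mean().item()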


@pytest.mark.forked
@pytest.mark.asyncio
async def test_channel_cache():
    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1

    c1 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
    c2 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
    c3 = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
    c3_again = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
    c1_again = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
    c4 = hivemind.ChannelCache.get_stub('localhost:1339', DHTStub, aio=True)
    c2_anew = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
    c1_yetagain = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)

    await asyncio.sleep(0.2)
    c1_anew = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
    c1_anew_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
    c1_otherstub = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=ConnectionHandlerStub)
    await asyncio.sleep(0.05)
    c1_otherstub_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False,
                                                        stub_type=ConnectionHandlerStub)

    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
    assert isinstance(all_channels[-1], ConnectionHandlerStub)
    assert 'aio' in repr(c2.rpc_find)
    assert 'aio' not in repr(c1.rpc_find)

    duplicates = {(c1, c1_again), (c1, c1_yetagain), (c1_again, c1_yetagain), (c3, c3_again),
                  (c1_anew, c1_anew_again), (c1_otherstub, c1_otherstub_again)}
    for i in range(len(all_channels)):
        for j in range(i + 1, len(all_channels)):
            ci, cj = all_channels[i], all_channels[j]
            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
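

# From the assertions above, ChannelCache appears to key stubs by the triple
# (target, aio, stub_type): repeated get_stub calls within
# EVICTION_PERIOD_SECONDS return the identical stub object, while calls after
# the eviction period (the asyncio.sleep(0.2) above) build a fresh one.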


def test_serialize_tensor():
    tensor = torch.randn(512, 12288)
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
        restored = hivemind.combine_from_streaming(chunks)
        assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    chunk_size = 30 * 1024
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)

    tensor = torch.randint(0, 100, (512, 1, 1))
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    scalar = torch.tensor(1.)
    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)

    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
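

# A hedged sketch of the streaming round-trip exercised above (illustrative
# helper; not collected by pytest): split a serialized tensor into size-bounded
# chunks, e.g. for streaming over an RPC channel, then reassemble and decode.
def _example_streaming_roundtrip(tensor, chunk_size=16384):
    serialized = serialize_torch_tensor(tensor, CompressionType.NONE)
    chunks = hivemind.split_for_streaming(serialized, chunk_size)
    return deserialize_torch_tensor(hivemind.combine_from_streaming(chunks))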


def test_serialize_tuple():
    test_pairs = (
        ((1, 2, 3), [1, 2, 3]),
        (('1', False, 0), ['1', False, 0]),
        (('1', False, 0), ('1', 0, 0)),
        (('1', b'qq', (2, 5, '0')), ['1', b'qq', (2, 5, '0')]),
    )

    for first, second in test_pairs:
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(first)) == first
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(second)) == second
        assert MSGPackSerializer.dumps(first) != MSGPackSerializer.dumps(second)
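

# The last assertion above relies on MSGPackSerializer producing distinct bytes
# for values that compare equal in Python: tuples vs. lists with the same items,
# and False vs. 0 (as in ('1', False, 0) vs ('1', 0, 0)), so that every value
# round-trips to exactly its original type.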


def test_split_parts():
    tensor = torch.randn(910, 512)
    serialized_tensor_part = serialize_torch_tensor(tensor, allow_inplace=False)
    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))

    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))

    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
    assert len(chunks3) == 1

    compressed_tensor_part = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)
    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))

    combined1 = hivemind.utils.combine_from_streaming(chunks1)
    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
    combined3 = hivemind.utils.combine_from_streaming(chunks3)
    combined4 = hivemind.utils.combine_from_streaming(chunks4)
    for combined in combined1, combined2, combined3:
        assert torch.allclose(tensor, deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)
    assert torch.allclose(tensor, deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)

    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
        with pytest.raises(RuntimeError):
            deserialize_torch_tensor(combined)
            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllreduceRunner


def test_generic_data_classes():
    from hivemind.utils import ValueWithExpiration, HeapEntry, DHTExpiration

    value_with_exp = ValueWithExpiration(value="string_value", expiration_time=DHTExpiration(10))
    assert value_with_exp.value == "string_value" and value_with_exp.expiration_time == DHTExpiration(10)

    heap_entry = HeapEntry(expiration_time=DHTExpiration(10), key="string_value")
    assert heap_entry.key == "string_value" and heap_entry.expiration_time == DHTExpiration(10)

    sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
    sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
    assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])
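

# The asyncio helpers tested below mirror their synchronous namesakes
# (enumerate, zip, chain, next) for async iterators; amap_in_executor, as the
# name suggests, appears to run the mapped function in a background executor,
# prefetching up to max_prefetch items ahead of the consumer.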


@pytest.mark.asyncio
async def test_asyncio_utils():
    res = [i async for i, item in aenumerate(aiter('a', 'b', 'c'))]
    assert res == list(range(len(res)))

    num_steps = 0
    async for elem in amap_in_executor(lambda x: x ** 2, aiter(*range(100)), max_prefetch=5):
        assert elem == num_steps ** 2
        num_steps += 1
    assert num_steps == 100

    ours = [elem async for elem in amap_in_executor(max, aiter(*range(7)), aiter(*range(-50, 50, 10)), max_prefetch=1)]
    ref = list(map(max, range(7), range(-50, 50, 10)))
    assert ours == ref

    ours = [row async for row in azip(aiter('a', 'b', 'c'), aiter(1, 2, 3))]
    ref = list(zip(['a', 'b', 'c'], [1, 2, 3]))
    assert ours == ref

    async def _aiterate():
        yield 'foo'
        yield 'bar'
        yield 'baz'

    iterator = _aiterate()
    assert (await anext(iterator)) == 'foo'
    tail = [item async for item in iterator]
    assert tail == ['bar', 'baz']
    with pytest.raises(StopAsyncIteration):
        await anext(iterator)

    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ['foo', 'bar', 'baz'] + list(range(5))