test_util_modules.py

import asyncio
from concurrent.futures import CancelledError

import numpy as np
import pytest
import torch

import hivemind
from hivemind.proto.dht_pb2_grpc import DHTStub
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
from hivemind.utils import MSGPackSerializer
from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.mpfuture import FutureStateError


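# MPFuture.make_pair() creates two futures connected across process boundaries:
# a result set on either side becomes visible on both, and a finished future
# rejects further set_result/set_exception calls as well as cancellation.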
def test_mpfuture_result():
    f1, f2 = hivemind.MPFuture.make_pair()
    f1.set_result(321)
    assert f2.result() == 321
    assert f1.result() == 321

    for future in [f1, f2]:
        with pytest.raises(FutureStateError):
            future.set_result(123)
        with pytest.raises(FutureStateError):
            future.set_exception(ValueError())
        assert future.cancel() is False
        assert future.done() and not future.running() and not future.cancelled()

    f1, f2 = hivemind.MPFuture.make_pair()
    with pytest.raises(TimeoutError):
        f1.result(timeout=1e-3)

    f2.set_result(['abacaba', 123])
    assert f1.result() == ['abacaba', 123]


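# An exception set on one side of an MPFuture pair propagates to both sides:
# .exception() returns it and .result() re-raises it.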
def test_mpfuture_exception():
    f1, f2 = hivemind.MPFuture.make_pair()
    with pytest.raises(TimeoutError):
        f1.exception(timeout=1e-3)

    f2.set_exception(NotImplementedError())

    for future in [f1, f2]:
        assert isinstance(future.exception(), NotImplementedError)
        with pytest.raises(NotImplementedError):
            future.result()
        assert future.cancel() is False
        assert future.done() and not future.running() and not future.cancelled()


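# Cancelling either side of the pair cancels both: result() and exception()
# raise CancelledError, and the future's state can no longer be changed.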
def test_mpfuture_cancel():
    f1, f2 = hivemind.MPFuture.make_pair()
    assert not f2.cancelled()
    f1.cancel()

    for future in [f1, f2]:
        with pytest.raises(CancelledError):
            future.result()
        with pytest.raises(CancelledError):
            future.exception()
        with pytest.raises(FutureStateError):
            future.set_result(123)
        with pytest.raises(FutureStateError):
            future.set_exception(NotImplementedError())
        assert future.cancelled() and future.done() and not future.running()


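# set_running_or_notify_cancel() follows concurrent.futures semantics: True on
# the first call for a pending future, RuntimeError on repeated calls, and
# False once the future has been cancelled.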
def test_mpfuture_status():
    f1, f2 = hivemind.MPFuture.make_pair()
    assert f1.set_running_or_notify_cancel() is True
    for future in [f1, f2]:
        assert future.running() and not future.done() and not future.cancelled()
        with pytest.raises(RuntimeError):
            future.set_running_or_notify_cancel()

    f2.cancel()
    for future in [f1, f2]:
        assert not future.running() and future.done() and future.cancelled()
        assert future.set_running_or_notify_cancel() is False

    f1, f2 = hivemind.MPFuture.make_pair()
    f1.cancel()
    for future in [f1, f2]:
        assert future.set_running_or_notify_cancel() is False


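# MPFuture is awaitable: a result, cancellation, or exception produced on one
# side of the pair is observed when awaiting either side.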
@pytest.mark.asyncio
async def test_await_mpfuture():
    # await result
    f1, f2 = hivemind.MPFuture.make_pair()

    async def wait_and_assign():
        assert f2.set_running_or_notify_cancel() is True
        await asyncio.sleep(0.1)
        f2.set_result((123, 'ololo'))

    asyncio.create_task(wait_and_assign())
    for future in [f1, f2]:
        res = await future
        assert res == (123, 'ololo')

    # await cancel
    f1, f2 = hivemind.MPFuture.make_pair()

    async def wait_and_cancel():
        await asyncio.sleep(0.1)
        f1.cancel()

    asyncio.create_task(wait_and_cancel())
    for future in [f1, f2]:
        with pytest.raises(CancelledError):
            await future

    # await exception
    f1, f2 = hivemind.MPFuture.make_pair()

    async def wait_and_raise():
        await asyncio.sleep(0.1)
        f1.set_exception(SystemError())

    asyncio.create_task(wait_and_raise())
    for future in [f1, f2]:
        with pytest.raises(SystemError):
            await future


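# Every codec must round-trip within its tolerance: near-exact reconstruction
# for NONE, mean squared error below `alpha` for the 16-bit codecs and below
# `beta` for the 8-bit ones; an all-zero tensor must stay finite under every codec.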
def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
    assert error.square().mean() < beta

    zeros = torch.zeros(5, 5)
    for compression_type in CompressionType.values():
        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()


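# ChannelCache returns the identical stub object for repeated requests with the
# same (target, stub_type, aio) arguments, and opens a fresh one once the
# eviction period (EVICTION_PERIOD_SECONDS) has elapsed.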
@pytest.mark.forked
@pytest.mark.asyncio
async def test_channel_cache():
    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1

    c1 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
    c2 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
    c3 = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
    c3_again = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
    c1_again = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
    c4 = hivemind.ChannelCache.get_stub('localhost:1339', DHTStub, aio=True)
    c2_anew = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
    c1_yetagain = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)

    await asyncio.sleep(0.2)
    c1_anew = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
    c1_anew_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
    c1_otherstub = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=ConnectionHandlerStub)
    await asyncio.sleep(0.05)
    c1_otherstub_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False,
                                                        stub_type=ConnectionHandlerStub)

    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
    assert isinstance(all_channels[-1], ConnectionHandlerStub)
    assert 'aio' in repr(c2.rpc_find)
    assert 'aio' not in repr(c1.rpc_find)

    duplicates = {(c1, c1_again), (c1, c1_yetagain), (c1_again, c1_yetagain), (c3, c3_again),
                  (c1_anew, c1_anew_again), (c1_otherstub, c1_otherstub_again)}
    for i in range(len(all_channels)):
        for j in range(i + 1, len(all_channels)):
            ci, cj = all_channels[i], all_channels[j]
            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)


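# Chunked streaming must be lossless: split_for_streaming yields
# ceil(len(buffer) / chunk_size) chunks, e.g. a 100-byte buffer with
# chunk_size == 30 gives (100 - 1) // 30 + 1 == 4 chunks, and
# combine_from_streaming must restore the serialized tensor, including
# scalars, up to the compression error of the chosen codec.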
def test_serialize_tensor():
    tensor = torch.randn(512, 12288)

    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
        restored = hivemind.combine_from_streaming(chunks)
        assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    chunk_size = 30 * 1024
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)

    tensor = torch.randint(0, 100, (512, 1, 1))
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    scalar = torch.tensor(1.)
    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)

    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)


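# MSGPackSerializer must round-trip values exactly and keep tuples distinct
# from lists, and bools distinct from ints: ('1', False, 0) equals ('1', 0, 0)
# under Python's ==, yet the two must serialize to different bytes.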
def test_serialize_tuple():
    test_pairs = (
        ((1, 2, 3), [1, 2, 3]),
        (('1', False, 0), ['1', False, 0]),
        (('1', False, 0), ('1', 0, 0)),
        (('1', b'qq', (2, 5, '0')), ['1', b'qq', (2, 5, '0')]),
    )

    for first, second in test_pairs:
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(first)) == first
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(second)) == second
        assert MSGPackSerializer.dumps(first) != MSGPackSerializer.dumps(second)


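# An uncompressed serialized tensor occupies numel * element_size bytes, so
# splitting it into chunk_size-byte parts yields ceil(bytes / chunk_size)
# chunks; FLOAT16 compression shrinks that to numel * 2 bytes. Combining an
# incomplete stream must fail with RuntimeError on deserialization.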
def test_split_parts():
    tensor = torch.randn(910, 512)
    serialized_tensor_part = serialize_torch_tensor(tensor, allow_inplace=False)
    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))

    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))

    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
    assert len(chunks3) == 1

    compressed_tensor_part = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)
    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))

    combined1 = hivemind.utils.combine_from_streaming(chunks1)
    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
    combined3 = hivemind.utils.combine_from_streaming(chunks3)
    combined4 = hivemind.utils.combine_from_streaming(chunks4)
    for combined in combined1, combined2, combined3:
        assert torch.allclose(tensor, deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)
    assert torch.allclose(tensor, deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)

    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
        with pytest.raises(RuntimeError):
            deserialize_torch_tensor(combined)
            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceRunner


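# ValueWithExpiration and HeapEntry expose their fields by name and order by
# expiration_time first, so a heap of entries built from reversed expirations
# sorts back into ascending expiration order.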
def test_generic_data_classes():
    from hivemind.utils import ValueWithExpiration, HeapEntry, DHTExpiration

    value_with_exp = ValueWithExpiration(value="string_value", expiration_time=DHTExpiration(10))
    assert value_with_exp.value == "string_value" and value_with_exp.expiration_time == DHTExpiration(10)

    heap_entry = HeapEntry(expiration_time=DHTExpiration(10), key="string_value")
    assert heap_entry.key == "string_value" and heap_entry.expiration_time == DHTExpiration(10)

    sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
    sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
    assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])


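# aiter/aenumerate/azip/achain/anext mirror their synchronous counterparts for
# async iterators, while amap_in_executor maps a function over one or more
# async streams, prefetching at most max_prefetch results ahead.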
@pytest.mark.asyncio
async def test_asyncio_utils():
    res = [i async for i, item in aenumerate(aiter('a', 'b', 'c'))]
    assert res == list(range(len(res)))

    num_steps = 0
    async for elem in amap_in_executor(lambda x: x ** 2, aiter(*range(100)), max_prefetch=5):
        assert elem == num_steps ** 2
        num_steps += 1
    assert num_steps == 100

    ours = [elem async for elem in amap_in_executor(max, aiter(*range(7)), aiter(*range(-50, 50, 10)), max_prefetch=1)]
    ref = list(map(max, range(7), range(-50, 50, 10)))
    assert ours == ref

    ours = [row async for row in azip(aiter('a', 'b', 'c'), aiter(1, 2, 3))]
    ref = list(zip(['a', 'b', 'c'], [1, 2, 3]))
    assert ours == ref

    async def _aiterate():
        yield 'foo'
        yield 'bar'
        yield 'baz'

    iterator = _aiterate()
    assert (await anext(iterator)) == 'foo'
    tail = [item async for item in iterator]
    assert tail == ['bar', 'baz']
    with pytest.raises(StopAsyncIteration):
        await anext(iterator)

    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ['foo', 'bar', 'baz'] + list(range(5))