test_util_modules.py

import asyncio
from concurrent.futures import CancelledError

import numpy as np
import pytest
import torch

import hivemind
from hivemind.proto.dht_pb2_grpc import DHTStub
from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
from hivemind.utils import MSGPackSerializer
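

# MPFuture is hivemind's multiprocessing-aware future: make_pair() returns two
# connected futures, and a result, exception, or cancellation set on either side
# becomes visible on the other. The tests below exercise that contract.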
def test_mpfuture_result():
    f1, f2 = hivemind.MPFuture.make_pair()
    f1.set_result(321)
    assert f2.result() == 321
    assert f1.result() == 321

    for future in [f1, f2]:
        with pytest.raises(RuntimeError):
            future.set_result(123)
        with pytest.raises(RuntimeError):
            future.set_exception(ValueError())
        assert future.cancel() is False
        assert future.done() and not future.running() and not future.cancelled()

    f1, f2 = hivemind.MPFuture.make_pair()
    with pytest.raises(TimeoutError):
        f1.result(timeout=1e-3)

    f2.set_result(['abacaba', 123])
    assert f1.result() == ['abacaba', 123]


def test_mpfuture_exception():
    f1, f2 = hivemind.MPFuture.make_pair()
    with pytest.raises(TimeoutError):
        f1.exception(timeout=1e-3)

    f2.set_exception(NotImplementedError())

    for future in [f1, f2]:
        assert isinstance(future.exception(), NotImplementedError)
        with pytest.raises(NotImplementedError):
            future.result()
        assert future.cancel() is False
        assert future.done() and not future.running() and not future.cancelled()


def test_mpfuture_cancel():
    f1, f2 = hivemind.MPFuture.make_pair()
    assert not f2.cancelled()
    f1.cancel()

    for future in [f1, f2]:
        with pytest.raises(CancelledError):
            future.result()
        with pytest.raises(CancelledError):
            future.exception()
        with pytest.raises(RuntimeError):
            future.set_result(123)
        with pytest.raises(RuntimeError):
            future.set_exception(NotImplementedError())
        assert future.cancelled() and future.done() and not future.running()


def test_mpfuture_status():
    f1, f2 = hivemind.MPFuture.make_pair()
    assert f1.set_running_or_notify_cancel() is True
    for future in [f1, f2]:
        assert future.running() and not future.done() and not future.cancelled()
        with pytest.raises(RuntimeError):
            future.set_running_or_notify_cancel()
    f2.cancel()
    for future in [f1, f2]:
        assert not future.running() and future.done() and future.cancelled()
        assert future.set_running_or_notify_cancel() is False

    f1, f2 = hivemind.MPFuture.make_pair()
    f1.cancel()
    for future in [f1, f2]:
        assert future.set_running_or_notify_cancel() is False


@pytest.mark.asyncio
async def test_await_mpfuture():
    # await result
    f1, f2 = hivemind.MPFuture.make_pair()

    async def wait_and_assign():
        assert f2.set_running_or_notify_cancel() is True
        await asyncio.sleep(0.1)
        f2.set_result((123, 'ololo'))

    asyncio.create_task(wait_and_assign())
    for future in [f1, f2]:
        res = await future
        assert res == (123, 'ololo')

    # await cancel
    f1, f2 = hivemind.MPFuture.make_pair()

    async def wait_and_cancel():
        await asyncio.sleep(0.1)
        f1.cancel()

    asyncio.create_task(wait_and_cancel())
    for future in [f1, f2]:
        with pytest.raises(CancelledError):
            await future

    # await exception
    f1, f2 = hivemind.MPFuture.make_pair()

    async def wait_and_raise():
        await asyncio.sleep(0.1)
        f1.set_exception(SystemError())

    asyncio.create_task(wait_and_raise())
    for future in [f1, f2]:
        with pytest.raises(SystemError):
            await future
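

# Round-trip checks for tensor compression: CompressionType.NONE must be lossless,
# while the lossy FLOAT16 and MEANSTD_LAST_AXIS_FLOAT16 schemes only need to keep
# the mean squared reconstruction error below alpha.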
def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
    torch.manual_seed(0)
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor

    X = torch.randn(*size)
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
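

# ChannelCache deduplicates gRPC stubs by (target, aio, stub_type) and evicts idle
# entries after EVICTION_PERIOD_SECONDS; sleeping past that period forces re-creation,
# which is why c1_anew is expected to be a fresh stub rather than a cached duplicate.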
@pytest.mark.forked
@pytest.mark.asyncio
async def test_channel_cache():
    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1

    c1 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
    c2 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
    c3 = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
    c3_again = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
    c1_again = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
    c4 = hivemind.ChannelCache.get_stub('localhost:1339', DHTStub, aio=True)
    c2_anew = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
    c1_yetagain = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)

    await asyncio.sleep(0.2)
    c1_anew = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
    c1_anew_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
    c1_otherstub = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=ConnectionHandlerStub)
    await asyncio.sleep(0.05)
    c1_otherstub_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False,
                                                        stub_type=ConnectionHandlerStub)
    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]

    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
    assert isinstance(all_channels[-1], ConnectionHandlerStub)
    assert 'aio' in repr(c2.rpc_find)
    assert 'aio' not in repr(c1.rpc_find)

    duplicates = {(c1, c1_again), (c1, c1_yetagain), (c1_again, c1_yetagain), (c3, c3_again),
                  (c1_anew, c1_anew_again), (c1_otherstub, c1_otherstub_again)}
    for i in range(len(all_channels)):
        for j in range(i + 1, len(all_channels)):
            ci, cj = all_channels[i], all_channels[j]
            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
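

# split_for_streaming slices a serialized tensor's byte buffer into chunks of at most
# chunk_size bytes, so the expected chunk count is the ceiling division
# (len(buffer) - 1) // chunk_size + 1 == ceil(len(buffer) / chunk_size);
# combine_from_streaming must reassemble the chunks into the original message.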
def test_serialize_tensor():
    tensor = torch.randn(512, 12288)

    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.NONE)
    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
        restored = hivemind.combine_from_streaming(chunks)
        assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)

    chunk_size = 30 * 1024
    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.FLOAT16)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)

    tensor = torch.randint(0, 100, (512, 1, 1))
    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.NONE)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
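

# MSGPackSerializer must round-trip each value exactly (tuples stay tuples, not lists)
# and keep distinctions that Python equality blurs: ('1', False, 0) == ('1', 0, 0),
# yet each pair below must serialize to different bytes.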
def test_serialize_tuple():
    test_pairs = (
        ((1, 2, 3), [1, 2, 3]),
        (('1', False, 0), ['1', False, 0]),
        (('1', False, 0), ('1', 0, 0)),
        (('1', b'qq', (2, 5, '0')), ['1', b'qq', (2, 5, '0')]),
    )

    for first, second in test_pairs:
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(first)) == first
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(second)) == second
        assert MSGPackSerializer.dumps(first) != MSGPackSerializer.dumps(second)
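

# An uncompressed float32 tensor serializes to numel() * element_size() bytes of buffer,
# and FLOAT16 compression stores 2 bytes per element; the expected chunk counts follow.
# Deserializing a stream that is missing chunks must fail with RuntimeError.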
def test_split_parts():
    tensor = torch.randn(910, 512)
    serialized_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, allow_inplace=False)
    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))

    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))

    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
    assert len(chunks3) == 1

    compressed_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, hivemind.CompressionType.FLOAT16,
                                                                   allow_inplace=False)
    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))

    combined1 = hivemind.utils.combine_from_streaming(chunks1)
    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
    combined3 = hivemind.utils.combine_from_streaming(chunks3)
    combined4 = hivemind.utils.combine_from_streaming(chunks4)
    for combined in combined1, combined2, combined3:
        assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)
    assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)

    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
        with pytest.raises(RuntimeError):
            hivemind.deserialize_torch_tensor(combined)
            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceProtocol
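

# ValueWithExpiration and HeapEntry expose their fields by name; HeapEntry compares by
# expiration_time first, so entries built in reverse order must sort back into ascending
# expiration order.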
def test_generic_data_classes():
    from hivemind.utils import ValueWithExpiration, HeapEntry, DHTExpiration

    value_with_exp = ValueWithExpiration(value="string_value", expiration_time=DHTExpiration(10))
    assert value_with_exp.value == "string_value" and value_with_exp.expiration_time == DHTExpiration(10)

    heap_entry = HeapEntry(expiration_time=DHTExpiration(10), key="string_value")
    assert heap_entry.key == "string_value" and heap_entry.expiration_time == DHTExpiration(10)

    sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
    sorted_heap_entries = sorted(HeapEntry(expiration_time=DHTExpiration(value), key="any")
                                 for value in range(1, 1000)[::-1])
    assert all(heap_entry.expiration_time == value
               for heap_entry, value in zip(sorted_heap_entries, sorted_expirations))