test_util_modules.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import asyncio
  2. import torch
  3. import pytest
  4. import hivemind
  5. from hivemind.proto.dht_pb2_grpc import DHTStub
  6. from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
  7. from concurrent.futures import CancelledError
  8. def test_mpfuture_result():
  9. f1, f2 = hivemind.MPFuture.make_pair()
  10. f1.set_result(321)
  11. assert f2.result() == 321
  12. assert f1.result() == 321
  13. for future in [f1, f2]:
  14. with pytest.raises(RuntimeError):
  15. future.set_result(123)
  16. with pytest.raises(RuntimeError):
  17. future.set_exception(ValueError())
  18. assert future.cancel() is False
  19. assert future.done() and not future.running() and not future.cancelled()
  20. f1, f2 = hivemind.MPFuture.make_pair()
  21. with pytest.raises(TimeoutError):
  22. f1.result(timeout=1e-3)
  23. f2.set_result(['abacaba', 123])
  24. assert f1.result() == ['abacaba', 123]
  25. def test_mpfuture_exception():
  26. f1, f2 = hivemind.MPFuture.make_pair()
  27. with pytest.raises(TimeoutError):
  28. f1.exception(timeout=1e-3)
  29. f2.set_exception(NotImplementedError())
  30. for future in [f1, f2]:
  31. assert isinstance(future.exception(), NotImplementedError)
  32. with pytest.raises(NotImplementedError):
  33. future.result()
  34. assert future.cancel() is False
  35. assert future.done() and not future.running() and not future.cancelled()
  36. def test_mpfuture_cancel():
  37. f1, f2 = hivemind.MPFuture.make_pair()
  38. assert not f2.cancelled()
  39. f1.cancel()
  40. for future in [f1, f2]:
  41. with pytest.raises(CancelledError):
  42. future.result()
  43. with pytest.raises(CancelledError):
  44. future.exception()
  45. with pytest.raises(RuntimeError):
  46. future.set_result(123)
  47. with pytest.raises(RuntimeError):
  48. future.set_exception(NotImplementedError)
  49. assert future.cancelled() and future.done() and not future.running()
  50. def test_mpfuture_status():
  51. f1, f2 = hivemind.MPFuture.make_pair()
  52. assert f1.set_running_or_notify_cancel() is True
  53. for future in [f1, f2]:
  54. assert future.running() and not future.done() and not future.cancelled()
  55. with pytest.raises(RuntimeError):
  56. future.set_running_or_notify_cancel()
  57. f2.cancel()
  58. for future in [f1, f2]:
  59. assert not future.running() and future.done() and future.cancelled()
  60. assert future.set_running_or_notify_cancel() is False
  61. f1, f2 = hivemind.MPFuture.make_pair()
  62. f1.cancel()
  63. for future in [f1, f2]:
  64. assert future.set_running_or_notify_cancel() is False
  65. @pytest.mark.asyncio
  66. async def test_await_mpfuture():
  67. # await result
  68. f1, f2 = hivemind.MPFuture.make_pair()
  69. async def wait_and_assign():
  70. assert f2.set_running_or_notify_cancel() is True
  71. await asyncio.sleep(0.1)
  72. f2.set_result((123, 'ololo'))
  73. asyncio.create_task(wait_and_assign())
  74. for future in [f1, f2]:
  75. res = await future
  76. assert res == (123, 'ololo')
  77. # await cancel
  78. f1, f2 = hivemind.MPFuture.make_pair()
  79. async def wait_and_cancel():
  80. await asyncio.sleep(0.1)
  81. f1.cancel()
  82. asyncio.create_task(wait_and_cancel())
  83. for future in [f1, f2]:
  84. with pytest.raises(CancelledError):
  85. await future
  86. # await exception
  87. f1, f2 = hivemind.MPFuture.make_pair()
  88. async def wait_and_raise():
  89. await asyncio.sleep(0.1)
  90. f1.set_exception(SystemError())
  91. asyncio.create_task(wait_and_raise())
  92. for future in [f1, f2]:
  93. with pytest.raises(SystemError):
  94. await future
  95. def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
  96. torch.manual_seed(0)
  97. from hivemind.proto.runtime_pb2 import CompressionType
  98. from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
  99. X = torch.randn(*size)
  100. assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
  101. error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16))-X
  102. assert error.square().mean() < alpha
  103. error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
  104. assert error.square().mean() < alpha
  105. @pytest.mark.forked
  106. @pytest.mark.asyncio
  107. async def test_channel_cache():
  108. hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
  109. hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1
  110. c1 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
  111. c2 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
  112. c3 = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
  113. c3_again = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
  114. c1_again = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
  115. c4 = hivemind.ChannelCache.get_stub('localhost:1339', DHTStub, aio=True)
  116. c2_anew = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
  117. c1_yetagain = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
  118. await asyncio.sleep(0.2)
  119. c1_anew = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
  120. c1_anew_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
  121. c1_otherstub = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=ConnectionHandlerStub)
  122. await asyncio.sleep(0.05)
  123. c1_otherstub_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False,
  124. stub_type=ConnectionHandlerStub)
  125. all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
  126. assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
  127. assert isinstance(all_channels[-1], ConnectionHandlerStub)
  128. assert 'aio' in repr(c2.rpc_find)
  129. assert 'aio' not in repr(c1.rpc_find)
  130. duplicates = {(c1, c1_again), (c1, c1_yetagain), (c1_again, c1_yetagain), (c3, c3_again),
  131. (c1_anew, c1_anew_again), (c1_otherstub, c1_otherstub_again)}
  132. for i in range(len(all_channels)):
  133. for j in range(i + 1, len(all_channels)):
  134. ci, cj = all_channels[i], all_channels[j]
  135. assert (ci is cj) == ((ci, cj) in duplicates), (i, j)