# test_util_modules.py
import asyncio
import torch
import pytest
import hivemind
from concurrent.futures import CancelledError
  6. def test_mpfuture_result():
  7. f1, f2 = hivemind.MPFuture.make_pair()
  8. f1.set_result(321)
  9. assert f2.result() == 321
  10. assert f1.result() == 321
  11. for future in [f1, f2]:
  12. with pytest.raises(RuntimeError):
  13. future.set_result(123)
  14. with pytest.raises(RuntimeError):
  15. future.set_exception(ValueError())
  16. assert future.cancel() is False
  17. assert future.done() and not future.running() and not future.cancelled()
  18. f1, f2 = hivemind.MPFuture.make_pair()
  19. with pytest.raises(TimeoutError):
  20. f1.result(timeout=1e-3)
  21. f2.set_result(['abacaba', 123])
  22. assert f1.result() == ['abacaba', 123]
  23. def test_mpfuture_exception():
  24. f1, f2 = hivemind.MPFuture.make_pair()
  25. with pytest.raises(TimeoutError):
  26. f1.exception(timeout=1e-3)
  27. f2.set_exception(NotImplementedError())
  28. for future in [f1, f2]:
  29. assert isinstance(future.exception(), NotImplementedError)
  30. with pytest.raises(NotImplementedError):
  31. future.result()
  32. assert future.cancel() is False
  33. assert future.done() and not future.running() and not future.cancelled()
  34. def test_mpfuture_cancel():
  35. f1, f2 = hivemind.MPFuture.make_pair()
  36. assert not f2.cancelled()
  37. f1.cancel()
  38. for future in [f1, f2]:
  39. with pytest.raises(CancelledError):
  40. future.result()
  41. with pytest.raises(CancelledError):
  42. future.exception()
  43. with pytest.raises(RuntimeError):
  44. future.set_result(123)
  45. with pytest.raises(RuntimeError):
  46. future.set_exception(NotImplementedError)
  47. assert future.cancelled() and future.done() and not future.running()
  48. def test_mpfuture_status():
  49. f1, f2 = hivemind.MPFuture.make_pair()
  50. assert f1.set_running_or_notify_cancel() is True
  51. for future in [f1, f2]:
  52. assert future.running() and not future.done() and not future.cancelled()
  53. with pytest.raises(RuntimeError):
  54. future.set_running_or_notify_cancel()
  55. f2.cancel()
  56. for future in [f1, f2]:
  57. assert not future.running() and future.done() and future.cancelled()
  58. assert future.set_running_or_notify_cancel() is False
  59. f1, f2 = hivemind.MPFuture.make_pair()
  60. f1.cancel()
  61. for future in [f1, f2]:
  62. assert future.set_running_or_notify_cancel() is False
  63. @pytest.mark.asyncio
  64. async def test_await_mpfuture():
  65. # await result
  66. f1, f2 = hivemind.MPFuture.make_pair()
  67. async def wait_and_assign():
  68. assert f2.set_running_or_notify_cancel() is True
  69. await asyncio.sleep(0.1)
  70. f2.set_result((123, 'ololo'))
  71. asyncio.create_task(wait_and_assign())
  72. for future in [f1, f2]:
  73. res = await future
  74. assert res == (123, 'ololo')
  75. # await cancel
  76. f1, f2 = hivemind.MPFuture.make_pair()
  77. async def wait_and_cancel():
  78. await asyncio.sleep(0.1)
  79. f1.cancel()
  80. asyncio.create_task(wait_and_cancel())
  81. for future in [f1, f2]:
  82. with pytest.raises(CancelledError):
  83. await future
  84. # await exception
  85. f1, f2 = hivemind.MPFuture.make_pair()
  86. async def wait_and_raise():
  87. await asyncio.sleep(0.1)
  88. f1.set_exception(SystemError())
  89. asyncio.create_task(wait_and_raise())
  90. for future in [f1, f2]:
  91. with pytest.raises(SystemError):
  92. await future
  93. def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
  94. torch.manual_seed(0)
  95. from hivemind.proto.runtime_pb2 import CompressionType
  96. from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
  97. X = torch.randn(*size)
  98. assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
  99. error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16))-X
  100. assert error.square().mean() < alpha
  101. error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
  102. assert error.square().mean() < alpha