test_dht.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import asyncio
  2. import random
  3. import time
  4. import pytest
  5. import hivemind
  6. from hivemind import LOCALHOST, strip_port
  7. @pytest.mark.skip
  8. @pytest.mark.forked
  9. def test_get_store():
  10. #TODO this raises: Failed to initialize p2p daemon: [Errno 111] Connection refused
  11. peers = []
  12. for i in range(10):
  13. neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
  14. peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
  15. node1, node2 = random.sample(peers, 2)
  16. assert node1.store('key1', 'value1', expiration_time=hivemind.get_dht_time() + 30)
  17. assert node1.get('key1').value == 'value1'
  18. assert node2.get('key1').value == 'value1'
  19. assert node2.get('key2') is None
  20. future = node1.get('foo', return_future=True)
  21. assert future.result() is None
  22. future = node1.get('foo', return_future=True)
  23. future.cancel()
  24. assert node2.store('key1', 123, expiration_time=hivemind.get_dht_time() + 31)
  25. assert node2.store('key2', 456, expiration_time=hivemind.get_dht_time() + 32)
  26. assert node1.get('key1', latest=True).value == 123
  27. assert node1.get('key2').value == 456
  28. assert node1.store('key2', subkey='subkey1', value=789, expiration_time=hivemind.get_dht_time() + 32)
  29. assert node2.store('key2', subkey='subkey2', value='pew', expiration_time=hivemind.get_dht_time() + 32)
  30. found_dict = node1.get('key2', latest=True).value
  31. assert isinstance(found_dict, dict) and len(found_dict) == 2
  32. assert found_dict['subkey1'].value == 789 and found_dict['subkey2'].value == 'pew'
  33. for peer in peers:
  34. peer.shutdown()
  35. async def dummy_dht_coro(self, node):
  36. return 'pew'
  37. async def dummy_dht_coro_error(self, node):
  38. raise ValueError("Oops, i did it again...")
  39. async def dummy_dht_coro_stateful(self, node):
  40. self._x_dummy = getattr(self, '_x_dummy', 123) + 1
  41. return self._x_dummy
  42. async def dummy_dht_coro_long(self, node):
  43. await asyncio.sleep(0.25)
  44. return self._x_dummy ** 2
  45. async def dummy_dht_coro_for_cancel(self, node):
  46. self._x_dummy = -100
  47. await asyncio.sleep(0.5)
  48. self._x_dummy = 999
  49. @pytest.mark.forked
  50. def test_run_coroutine():
  51. dht = hivemind.DHT(start=True)
  52. assert dht.run_coroutine(dummy_dht_coro) == 'pew'
  53. with pytest.raises(ValueError):
  54. res = dht.run_coroutine(dummy_dht_coro_error)
  55. bg_task = dht.run_coroutine(dummy_dht_coro_long, return_future=True)
  56. assert dht.run_coroutine(dummy_dht_coro_stateful) == 124
  57. assert dht.run_coroutine(dummy_dht_coro_stateful) == 125
  58. assert dht.run_coroutine(dummy_dht_coro_stateful) == 126
  59. assert not hasattr(dht, '_x_dummy')
  60. assert bg_task.result() == 126 ** 2
  61. future = dht.run_coroutine(dummy_dht_coro_for_cancel, return_future=True)
  62. time.sleep(0.25)
  63. future.cancel()
  64. assert dht.run_coroutine(dummy_dht_coro_stateful) == -99
  65. @pytest.mark.skip
  66. @pytest.mark.forked
  67. def test_dht_get_address(addr=LOCALHOST, dummy_endpoint='123.45.67.89:*'):
  68. node1 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
  69. node2 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node1.port}"])
  70. node3 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node2.port}"])
  71. assert addr in node3.get_visible_address(num_peers=2)
  72. node4 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
  73. with pytest.raises(ValueError):
  74. node4.get_visible_address()
  75. assert node4.get_visible_address(peers=[f'{addr}:{node1.port}']).endswith(addr)
  76. node5 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", endpoint=f"{dummy_endpoint}")
  77. assert node5.get_visible_address() == strip_port(dummy_endpoint)