test_p2p_daemon.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
import asyncio
import multiprocessing as mp
import os
import subprocess

import numpy as np
import pytest

import hivemind.p2p
from hivemind.p2p import P2P
  8. RUNNING = 'running'
  9. NOT_RUNNING = 'not running'
  10. CHECK_PID_CMD = '''
  11. if ps -p {0} > /dev/null;
  12. then
  13. echo "{1}"
  14. else
  15. echo "{2}"
  16. fi
  17. '''
  18. def is_process_running(pid: int) -> bool:
  19. cmd = CHECK_PID_CMD.format(pid, RUNNING, NOT_RUNNING)
  20. return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING
  21. @pytest.mark.asyncio
  22. async def test_daemon_killed_on_del():
  23. p2p_daemon = await P2P.create()
  24. child_pid = p2p_daemon._child.pid
  25. assert is_process_running(child_pid)
  26. p2p_daemon.__del__()
  27. assert not is_process_running(child_pid)
  28. def handle_square(x):
  29. return x ** 2
  30. def handle_add(args):
  31. result = args[0]
  32. for i in range(1, len(args)):
  33. result = result + args[i]
  34. return result
  35. @pytest.mark.parametrize(
  36. "test_input,handle",
  37. [
  38. pytest.param(10, handle_square, id="square_integer"),
  39. pytest.param((1, 2), handle_add, id="add_integers"),
  40. pytest.param(([1, 2, 3], [12, 13]), handle_add, id="add_lists"),
  41. pytest.param(2, lambda x: x ** 3, id="lambda")
  42. ]
  43. )
  44. @pytest.mark.asyncio
  45. async def test_call_peer_single_process(test_input, handle, handler_name="handle"):
  46. server = await P2P.create()
  47. server_pid = server._child.pid
  48. await server.add_stream_handler(handler_name, handle)
  49. assert is_process_running(server_pid)
  50. client = await P2P.create()
  51. client_pid = client._child.pid
  52. assert is_process_running(client_pid)
  53. await asyncio.sleep(1)
  54. result = await client.call_peer_handler(server.id, handler_name, test_input)
  55. assert result == handle(test_input)
  56. await server.stop_listening()
  57. server.__del__()
  58. assert not is_process_running(server_pid)
  59. client.__del__()
  60. assert not is_process_running(client_pid)
  61. @pytest.mark.asyncio
  62. async def test_call_peer_different_processes():
  63. handler_name = "square"
  64. test_input = np.random.randn(2, 3)
  65. server_side, client_side = mp.Pipe()
  66. response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
  67. response_received.value = 0
  68. async def run_server():
  69. server = await P2P.create()
  70. server_pid = server._child.pid
  71. await server.add_stream_handler(handler_name, handle_square)
  72. assert is_process_running(server_pid)
  73. server_side.send(server.id)
  74. while response_received.value == 0:
  75. await asyncio.sleep(0.5)
  76. await server.stop_listening()
  77. server.__del__()
  78. assert not is_process_running(server_pid)
  79. def server_target():
  80. asyncio.run(run_server())
  81. proc = mp.Process(target=server_target)
  82. proc.start()
  83. client = await P2P.create()
  84. client_pid = client._child.pid
  85. assert is_process_running(client_pid)
  86. await asyncio.sleep(1)
  87. peer_id = client_side.recv()
  88. result = await client.call_peer_handler(peer_id, handler_name, test_input)
  89. assert np.allclose(result, handle_square(test_input))
  90. response_received.value = 1
  91. client.__del__()
  92. assert not is_process_running(client_pid)
  93. proc.join()
  94. @pytest.mark.parametrize(
  95. "test_input,handle",
  96. [
  97. pytest.param(np.random.randn(2, 3), handle_square, id="square"),
  98. pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"),
  99. ]
  100. )
  101. @pytest.mark.asyncio
  102. async def test_call_peer_numpy(test_input, handle, handler_name="handle"):
  103. server = await P2P.create()
  104. await server.add_stream_handler(handler_name, handle)
  105. client = await P2P.create()
  106. await asyncio.sleep(1)
  107. result = await client.call_peer_handler(server.id, handler_name, test_input)
  108. assert np.allclose(result, handle(test_input))
  109. @pytest.mark.asyncio
  110. async def test_call_peer_error(handler_name="handle"):
  111. server = await P2P.create()
  112. await server.add_stream_handler(handler_name, handle_add)
  113. client = await P2P.create()
  114. await asyncio.sleep(1)
  115. result = await client.call_peer_handler(server.id, handler_name,
  116. [np.zeros((2, 3)), np.zeros((3, 2))])
  117. assert type(result) == ValueError