# test_p2p_daemon.py (6.3 KB)
  1. import asyncio
  2. import multiprocessing as mp
  3. import subprocess
  4. from hivemind.p2p.p2p_daemon_bindings.datastructures import ID
  5. import numpy as np
  6. import pytest
  7. from hivemind.p2p import P2P
  8. from hivemind.proto import dht_pb2
  9. RUNNING = 'running'
  10. NOT_RUNNING = 'not running'
  11. CHECK_PID_CMD = '''
  12. if ps -p {0} > /dev/null;
  13. then
  14. echo "{1}"
  15. else
  16. echo "{2}"
  17. fi
  18. '''
  19. def is_process_running(pid: int) -> bool:
  20. cmd = CHECK_PID_CMD.format(pid, RUNNING, NOT_RUNNING)
  21. return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING
  22. @pytest.mark.asyncio
  23. async def test_daemon_killed_on_del():
  24. p2p_daemon = await P2P.create()
  25. child_pid = p2p_daemon._child.pid
  26. assert is_process_running(child_pid)
  27. p2p_daemon.__del__()
  28. assert not is_process_running(child_pid)
  29. def handle_square(x):
  30. return x ** 2
  31. def handle_add(args):
  32. result = args[0]
  33. for i in range(1, len(args)):
  34. result = result + args[i]
  35. return result
  36. @pytest.mark.parametrize(
  37. 'should_cancel', [True, False]
  38. )
  39. @pytest.mark.asyncio
  40. async def test_call_unary_handler(should_cancel, handle_name="handle"):
  41. handler_cancelled = False
  42. async def ping_handler(request, context):
  43. try:
  44. await asyncio.sleep(2)
  45. except asyncio.CancelledError:
  46. nonlocal handler_cancelled
  47. handler_cancelled = True
  48. return dht_pb2.PingResponse(
  49. peer=dht_pb2.NodeInfo(
  50. node_id=context.ours_id.encode(), rpc_port=context.ours_port),
  51. sender_endpoint=context.handle_name, available=True)
  52. server = await P2P.create()
  53. server_pid = server._child.pid
  54. await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
  55. dht_pb2.PingResponse)
  56. assert is_process_running(server_pid)
  57. client = await P2P.create()
  58. client_pid = client._child.pid
  59. assert is_process_running(client_pid)
  60. ping_request = dht_pb2.PingRequest(
  61. peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
  62. validate=True)
  63. expected_response = dht_pb2.PingResponse(
  64. peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
  65. sender_endpoint=handle_name, available=True)
  66. await asyncio.sleep(1)
  67. libp2p_server_id = ID.from_base58(server.id)
  68. stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
  69. await P2P.send_raw_data(ping_request.SerializeToString(), writer)
  70. if should_cancel:
  71. writer.close()
  72. await asyncio.sleep(1)
  73. assert handler_cancelled
  74. else:
  75. result = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
  76. assert result == expected_response
  77. assert not handler_cancelled
  78. await server.stop_listening()
  79. server.__del__()
  80. assert not is_process_running(server_pid)
  81. client.__del__()
  82. assert not is_process_running(client_pid)
  83. @pytest.mark.parametrize(
  84. "test_input,handle",
  85. [
  86. pytest.param(10, handle_square, id="square_integer"),
  87. pytest.param((1, 2), handle_add, id="add_integers"),
  88. pytest.param(([1, 2, 3], [12, 13]), handle_add, id="add_lists"),
  89. pytest.param(2, lambda x: x ** 3, id="lambda")
  90. ]
  91. )
  92. @pytest.mark.asyncio
  93. async def test_call_peer_single_process(test_input, handle, handler_name="handle"):
  94. server = await P2P.create()
  95. server_pid = server._child.pid
  96. await server.add_stream_handler(handler_name, handle)
  97. assert is_process_running(server_pid)
  98. client = await P2P.create()
  99. client_pid = client._child.pid
  100. assert is_process_running(client_pid)
  101. await asyncio.sleep(1)
  102. result = await client.call_peer_handler(server.id, handler_name, test_input)
  103. assert result == handle(test_input)
  104. await server.stop_listening()
  105. server.__del__()
  106. assert not is_process_running(server_pid)
  107. client.__del__()
  108. assert not is_process_running(client_pid)
  109. async def run_server(handler_name, server_side, client_side, response_received):
  110. server = await P2P.create()
  111. server_pid = server._child.pid
  112. await server.add_stream_handler(handler_name, handle_square)
  113. assert is_process_running(server_pid)
  114. server_side.send(server.id)
  115. while response_received.value == 0:
  116. await asyncio.sleep(0.5)
  117. await server.stop_listening()
  118. server.__del__()
  119. assert not is_process_running(server_pid)
  120. def server_target(handler_name, server_side, client_side, response_received):
  121. asyncio.run(run_server(handler_name, server_side, client_side, response_received))
  122. @pytest.mark.asyncio
  123. async def test_call_peer_different_processes():
  124. handler_name = "square"
  125. test_input = np.random.randn(2, 3)
  126. server_side, client_side = mp.Pipe()
  127. response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
  128. response_received.value = 0
  129. proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
  130. proc.start()
  131. client = await P2P.create()
  132. client_pid = client._child.pid
  133. assert is_process_running(client_pid)
  134. await asyncio.sleep(1)
  135. peer_id = client_side.recv()
  136. result = await client.call_peer_handler(peer_id, handler_name, test_input)
  137. assert np.allclose(result, handle_square(test_input))
  138. response_received.value = 1
  139. client.__del__()
  140. assert not is_process_running(client_pid)
  141. proc.join()
  142. @pytest.mark.parametrize(
  143. "test_input,handle",
  144. [
  145. pytest.param(np.random.randn(2, 3), handle_square, id="square"),
  146. pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"),
  147. ]
  148. )
  149. @pytest.mark.asyncio
  150. async def test_call_peer_numpy(test_input, handle, handler_name="handle"):
  151. server = await P2P.create()
  152. await server.add_stream_handler(handler_name, handle)
  153. client = await P2P.create()
  154. await asyncio.sleep(1)
  155. result = await client.call_peer_handler(server.id, handler_name, test_input)
  156. assert np.allclose(result, handle(test_input))
  157. @pytest.mark.asyncio
  158. async def test_call_peer_error(handler_name="handle"):
  159. server = await P2P.create()
  160. await server.add_stream_handler(handler_name, handle_add)
  161. client = await P2P.create()
  162. await asyncio.sleep(1)
  163. result = await client.call_peer_handler(server.id, handler_name,
  164. [np.zeros((2, 3)), np.zeros((3, 2))])
  165. assert type(result) == ValueError