test_p2p_daemon.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import asyncio
  2. import multiprocessing as mp
  3. import subprocess
  4. from libp2p.peer.id import ID
  5. import numpy as np
  6. import pytest
  7. from hivemind.p2p import P2P
  8. from hivemind.proto import dht_pb2
  9. RUNNING = 'running'
  10. NOT_RUNNING = 'not running'
  11. CHECK_PID_CMD = '''
  12. if ps -p {0} > /dev/null;
  13. then
  14. echo "{1}"
  15. else
  16. echo "{2}"
  17. fi
  18. '''
  19. def is_process_running(pid: int) -> bool:
  20. cmd = CHECK_PID_CMD.format(pid, RUNNING, NOT_RUNNING)
  21. return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING
  22. @pytest.mark.asyncio
  23. async def test_daemon_killed_on_del():
  24. p2p_daemon = await P2P.create()
  25. child_pid = p2p_daemon._child.pid
  26. assert is_process_running(child_pid)
  27. p2p_daemon.__del__()
  28. assert not is_process_running(child_pid)
  29. def handle_square(x):
  30. return x ** 2
  31. def handle_add(args):
  32. result = args[0]
  33. for i in range(1, len(args)):
  34. result = result + args[i]
  35. return result
@pytest.mark.parametrize(
    'should_cancel', [True, False]
)
@pytest.mark.asyncio
async def test_call_unary_handler(should_cancel, handle_name="handle"):
    """Exercise a unary (protobuf request/response) handler over a raw daemon stream.

    With ``should_cancel=False`` the client reads the handler's ``PingResponse``
    and compares it with the expected value. With ``should_cancel=True`` the
    client closes the stream right after sending the request, and the test
    asserts that the in-flight handler observed ``asyncio.CancelledError``.
    """
    # Set by ping_handler when its sleep is interrupted by cancellation.
    handler_cancelled = False

    async def ping_handler(request, context):
        try:
            # Keep the handler suspended long enough for the client to close the
            # stream (should_cancel=True) while it is still waiting here.
            await asyncio.sleep(2)
        except asyncio.CancelledError:
            # NOTE(review): the cancellation is recorded but not re-raised, so the
            # handler still goes on to build a response after being cancelled.
            nonlocal handler_cancelled
            handler_cancelled = True
        return dht_pb2.PingResponse(
            peer=dht_pb2.NodeInfo(
                node_id=context.ours_id.encode(), rpc_port=context.ours_port),
            sender_endpoint=context.handle_name, available=True)

    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
                                   dht_pb2.PingResponse)
    assert is_process_running(server_pid)

    client = await P2P.create()
    client_pid = client._child.pid
    assert is_process_running(client_pid)

    ping_request = dht_pb2.PingRequest(
        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
        validate=True)
    # The test expects the handler's context.ours_id / context.ours_port to be the
    # server's id and host port, and context.handle_name to be handle_name.
    expected_response = dht_pb2.PingResponse(
        peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
        sender_endpoint=handle_name, available=True)

    # presumably gives both daemons time to finish starting up — TODO confirm
    await asyncio.sleep(1)

    # Open a raw stream through the client's underlying daemon client (client._client)
    # instead of a high-level call, which lets the test close the stream itself below.
    libp2p_server_id = ID.from_base58(server.id)
    stream_info, stream = await client._client.stream_open(libp2p_server_id, (handle_name,))

    await P2P.send_raw_data(ping_request.SerializeToString(), stream)

    if should_cancel:
        # Close the stream while the handler is still sleeping, then wait for the
        # server side to react before checking the flag.
        await stream.close()
        await asyncio.sleep(1)
        assert handler_cancelled
    else:
        result = await P2P.receive_protobuf(dht_pb2.PingResponse, stream)
        assert result == expected_response
        assert not handler_cancelled

    await server.stop_listening()
    server.__del__()
    assert not is_process_running(server_pid)

    client.__del__()
    assert not is_process_running(client_pid)
  83. @pytest.mark.parametrize(
  84. "test_input,handle",
  85. [
  86. pytest.param(10, handle_square, id="square_integer"),
  87. pytest.param((1, 2), handle_add, id="add_integers"),
  88. pytest.param(([1, 2, 3], [12, 13]), handle_add, id="add_lists"),
  89. pytest.param(2, lambda x: x ** 3, id="lambda")
  90. ]
  91. )
  92. @pytest.mark.asyncio
  93. async def test_call_peer_single_process(test_input, handle, handler_name="handle"):
  94. server = await P2P.create()
  95. server_pid = server._child.pid
  96. await server.add_stream_handler(handler_name, handle)
  97. assert is_process_running(server_pid)
  98. client = await P2P.create()
  99. client_pid = client._child.pid
  100. assert is_process_running(client_pid)
  101. await asyncio.sleep(1)
  102. result = await client.call_peer_handler(server.id, handler_name, test_input)
  103. assert result == handle(test_input)
  104. await server.stop_listening()
  105. server.__del__()
  106. assert not is_process_running(server_pid)
  107. client.__del__()
  108. assert not is_process_running(client_pid)
  109. @pytest.mark.asyncio
  110. async def test_call_peer_different_processes():
  111. handler_name = "square"
  112. test_input = np.random.randn(2, 3)
  113. server_side, client_side = mp.Pipe()
  114. response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
  115. response_received.value = 0
  116. async def run_server():
  117. server = await P2P.create()
  118. server_pid = server._child.pid
  119. await server.add_stream_handler(handler_name, handle_square)
  120. assert is_process_running(server_pid)
  121. server_side.send(server.id)
  122. while response_received.value == 0:
  123. await asyncio.sleep(0.5)
  124. await server.stop_listening()
  125. server.__del__()
  126. assert not is_process_running(server_pid)
  127. def server_target():
  128. asyncio.run(run_server())
  129. proc = mp.Process(target=server_target)
  130. proc.start()
  131. client = await P2P.create()
  132. client_pid = client._child.pid
  133. assert is_process_running(client_pid)
  134. await asyncio.sleep(1)
  135. peer_id = client_side.recv()
  136. result = await client.call_peer_handler(peer_id, handler_name, test_input)
  137. assert np.allclose(result, handle_square(test_input))
  138. response_received.value = 1
  139. client.__del__()
  140. assert not is_process_running(client_pid)
  141. proc.join()
  142. @pytest.mark.parametrize(
  143. "test_input,handle",
  144. [
  145. pytest.param(np.random.randn(2, 3), handle_square, id="square"),
  146. pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"),
  147. ]
  148. )
  149. @pytest.mark.asyncio
  150. async def test_call_peer_numpy(test_input, handle, handler_name="handle"):
  151. server = await P2P.create()
  152. await server.add_stream_handler(handler_name, handle)
  153. client = await P2P.create()
  154. await asyncio.sleep(1)
  155. result = await client.call_peer_handler(server.id, handler_name, test_input)
  156. assert np.allclose(result, handle(test_input))
  157. @pytest.mark.asyncio
  158. async def test_call_peer_error(handler_name="handle"):
  159. server = await P2P.create()
  160. await server.add_stream_handler(handler_name, handle_add)
  161. client = await P2P.create()
  162. await asyncio.sleep(1)
  163. result = await client.call_peer_handler(server.id, handler_name,
  164. [np.zeros((2, 3)), np.zeros((3, 2))])
  165. assert type(result) == ValueError