test_p2p_daemon.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
import asyncio
import multiprocessing as mp
import os
import subprocess
from functools import partial

import numpy as np
import pytest

from hivemind.p2p import P2P
from hivemind.p2p.p2p_daemon_bindings.datastructures import ID
from hivemind.proto import dht_pb2
# Markers echoed by the CHECK_PID_CMD shell snippet below.
RUNNING = 'running'
NOT_RUNNING = 'not running'
# POSIX shell snippet, to be formatted as CHECK_PID_CMD.format(pid, running_marker, not_running_marker).
# It prints placeholder {1} when `ps -p <pid>` exits successfully, and placeholder {2} otherwise.
# Requires a POSIX shell and `ps` on PATH.
CHECK_PID_CMD = '''
if ps -p {0} > /dev/null;
then
echo "{1}"
else
echo "{2}"
fi
'''
  20. def is_process_running(pid: int) -> bool:
  21. cmd = CHECK_PID_CMD.format(pid, RUNNING, NOT_RUNNING)
  22. return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING
  23. async def replicate_if_needed(p2p: P2P, replicate: bool):
  24. return await P2P.replicate(p2p._daemon_listen_port, p2p._host_port) if replicate else p2p
  25. @pytest.mark.asyncio
  26. async def test_daemon_killed_on_del():
  27. p2p_daemon = await P2P.create()
  28. child_pid = p2p_daemon._child.pid
  29. assert is_process_running(child_pid)
  30. await p2p_daemon.shutdown()
  31. assert not is_process_running(child_pid)
  32. @pytest.mark.asyncio
  33. async def test_daemon_replica_does_not_affect_primary():
  34. p2p_daemon = await P2P.create()
  35. p2p_replica = await P2P.replicate(p2p_daemon._daemon_listen_port, p2p_daemon._host_port)
  36. child_pid = p2p_daemon._child.pid
  37. assert is_process_running(child_pid)
  38. await p2p_replica.shutdown()
  39. assert is_process_running(child_pid)
  40. await p2p_daemon.shutdown()
  41. assert not is_process_running(child_pid)
  42. def handle_square(x):
  43. return x ** 2
  44. def handle_add(args):
  45. result = args[0]
  46. for i in range(1, len(args)):
  47. result = result + args[i]
  48. return result
  49. @pytest.mark.parametrize(
  50. 'should_cancel,replicate', [
  51. (True, False),
  52. (True, True),
  53. (False, False),
  54. (False, True),
  55. ]
  56. )
  57. @pytest.mark.asyncio
  58. async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
  59. handler_cancelled = False
  60. async def ping_handler(request, context):
  61. try:
  62. await asyncio.sleep(2)
  63. except asyncio.CancelledError:
  64. nonlocal handler_cancelled
  65. handler_cancelled = True
  66. return dht_pb2.PingResponse(
  67. peer=dht_pb2.NodeInfo(
  68. node_id=context.ours_id.to_bytes(), rpc_port=context.ours_port),
  69. sender_endpoint=context.peer(), available=True)
  70. server_primary = await P2P.create()
  71. server = await replicate_if_needed(server_primary, replicate)
  72. server_pid = server_primary._child.pid
  73. await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
  74. dht_pb2.PingResponse)
  75. assert is_process_running(server_pid)
  76. client_primary = await P2P.create()
  77. client = await replicate_if_needed(client_primary, replicate)
  78. client_pid = client_primary._child.pid
  79. assert is_process_running(client_pid)
  80. ping_request = dht_pb2.PingRequest(
  81. peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes(), rpc_port=client._host_port),
  82. validate=True)
  83. expected_response = dht_pb2.PingResponse(
  84. peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes(), rpc_port=server._host_port),
  85. sender_endpoint=client.endpoint, available=True)
  86. await asyncio.sleep(1)
  87. if should_cancel:
  88. stream_info, reader, writer = await client._client.stream_open(
  89. server.id, (handle_name,))
  90. await P2P.send_raw_data(ping_request.SerializeToString(), writer)
  91. writer.close()
  92. await asyncio.sleep(1)
  93. assert handler_cancelled
  94. else:
  95. result = await client.call_unary_handler(server.endpoint, handle_name, ping_request,
  96. dht_pb2.PingResponse)
  97. assert result == expected_response
  98. assert not handler_cancelled
  99. await server.stop_listening()
  100. await server_primary.shutdown()
  101. assert not is_process_running(server_pid)
  102. await client_primary.shutdown()
  103. assert not is_process_running(client_pid)
  104. @pytest.mark.parametrize(
  105. "test_input,handle",
  106. [
  107. pytest.param(10, handle_square, id="square_integer"),
  108. pytest.param((1, 2), handle_add, id="add_integers"),
  109. pytest.param(([1, 2, 3], [12, 13]), handle_add, id="add_lists"),
  110. pytest.param(2, lambda x: x ** 3, id="lambda")
  111. ]
  112. )
  113. @pytest.mark.asyncio
  114. async def test_call_peer_single_process(test_input, handle, handler_name="handle"):
  115. server = await P2P.create()
  116. server_pid = server._child.pid
  117. await server.add_stream_handler(handler_name, handle)
  118. assert is_process_running(server_pid)
  119. client = await P2P.create()
  120. client_pid = client._child.pid
  121. assert is_process_running(client_pid)
  122. # await asyncio.sleep(1)
  123. result = await client.call_peer_handler(server.endpoint, handler_name, test_input)
  124. assert result == handle(test_input)
  125. await server.shutdown()
  126. assert not is_process_running(server_pid)
  127. await client.shutdown()
  128. assert not is_process_running(client_pid)
  129. async def run_server(handler_name, server_side, client_side, response_received):
  130. server = await P2P.create()
  131. server_pid = server._child.pid
  132. await server.add_stream_handler(handler_name, handle_square)
  133. assert is_process_running(server_pid)
  134. server_side.send(server.id)
  135. while response_received.value == 0:
  136. await asyncio.sleep(0.5)
  137. await server.stop_listening()
  138. await server.shutdown()
  139. assert not is_process_running(server_pid)
  140. def server_target(handler_name, server_side, client_side, response_received):
  141. asyncio.run(run_server(handler_name, server_side, client_side, response_received))
  142. @pytest.mark.asyncio
  143. async def test_call_peer_different_processes():
  144. handler_name = "square"
  145. test_input = np.random.randn(2, 3)
  146. server_side, client_side = mp.Pipe()
  147. response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
  148. response_received.value = 0
  149. proc = mp.Process(target=server_target,
  150. args=(handler_name, server_side, client_side, response_received))
  151. proc.start()
  152. client = await P2P.create()
  153. client_pid = client._child.pid
  154. assert is_process_running(client_pid)
  155. # await asyncio.sleep(1)
  156. peer_id = client_side.recv()
  157. result = await client.call_peer_handler(peer_id.to_base58(), handler_name, test_input)
  158. assert np.allclose(result, handle_square(test_input))
  159. response_received.value = 1
  160. await client.shutdown()
  161. assert not is_process_running(client_pid)
  162. proc.join()
  163. @pytest.mark.parametrize(
  164. "test_input,handle,replicate",
  165. [
  166. pytest.param(np.random.randn(2, 3), handle_square, False, id="square_primary"),
  167. pytest.param(np.random.randn(2, 3), handle_square, True, id="square_replica"),
  168. pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, False, id="add_primary"),
  169. pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, True, id="add_replica"),
  170. ]
  171. )
  172. @pytest.mark.asyncio
  173. async def test_call_peer_numpy(test_input, handle, replicate, handler_name="handle"):
  174. server_primary = await P2P.create()
  175. server = await replicate_if_needed(server_primary, replicate)
  176. await server.add_stream_handler(handler_name, handle)
  177. client_primary = await P2P.create()
  178. client = await replicate_if_needed(client_primary, replicate)
  179. result = await client.call_peer_handler(server.endpoint, handler_name, test_input)
  180. assert np.allclose(result, handle(test_input))
  181. @pytest.mark.parametrize(
  182. "replicate",
  183. [
  184. pytest.param(False, id="primary"),
  185. pytest.param(True, id="replica"),
  186. ]
  187. )
  188. @pytest.mark.asyncio
  189. async def test_call_peer_error(replicate, handler_name="handle"):
  190. server_primary = await P2P.create()
  191. server = await replicate_if_needed(server_primary, replicate)
  192. await server.add_stream_handler(handler_name, handle_add)
  193. client_primary = await P2P.create()
  194. client = await replicate_if_needed(client_primary, replicate)
  195. result = await client.call_peer_handler(server.endpoint, handler_name,
  196. [np.zeros((2, 3)), np.zeros((3, 2))])
  197. assert type(result) == ValueError
  198. @pytest.mark.asyncio
  199. async def test_handlers_on_different_replicas(handler_name="handle"):
  200. def handler(arg, key):
  201. return key
  202. server_primary = await P2P.create()
  203. server_endpoint = server_primary.endpoint
  204. await server_primary.add_stream_handler(handler_name, partial(handler, key="primary"))
  205. server_replica1 = await replicate_if_needed(server_primary, True)
  206. await server_replica1.add_stream_handler(handler_name + "1", partial(handler, key="replica1"))
  207. server_replica2 = await replicate_if_needed(server_primary, True)
  208. await server_replica2.add_stream_handler(handler_name + "2", partial(handler, key="replica2"))
  209. client = await P2P.create()
  210. await asyncio.sleep(1)
  211. result = await client.call_peer_handler(server_endpoint, handler_name, "")
  212. assert result == "primary"
  213. result = await client.call_peer_handler(server_endpoint, handler_name + "1", "")
  214. assert result == "replica1"
  215. result = await client.call_peer_handler(server_endpoint, handler_name + "2", "")
  216. assert result == "replica2"
  217. await server_replica1.stop_listening()
  218. await server_replica2.stop_listening()
  219. # Primary does not handle replicas protocols
  220. with pytest.raises(P2P.IncompleteRead):
  221. await client.call_peer_handler(server_endpoint, handler_name + "1", "")
  222. with pytest.raises(P2P.IncompleteRead):
  223. await client.call_peer_handler(server_endpoint, handler_name + "2", "")
  224. await server_primary.stop_listening()
  225. await server_primary.shutdown()
  226. await client.shutdown()