# test_p2p_daemon.py
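"""Integration tests for hivemind's P2P daemon wrapper: daemon process lifecycle,
replica creation, unary (protobuf) handlers, and stream handlers called both from
the same process and from a separate process."""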

import asyncio
import multiprocessing as mp
import subprocess
from functools import partial
from typing import List

import numpy as np
import pytest
import torch
from multiaddr import Multiaddr

from hivemind.p2p import P2P, P2PHandlerError
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
from hivemind.proto import dht_pb2, runtime_pb2
from hivemind.utils import MSGPackSerializer
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor

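# Test helpers. is_process_running() probes `ps -p <pid>` so the tests can check
# whether the daemon child process spawned by P2P.create() is still alive.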
def is_process_running(pid: int) -> bool:
    return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0

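# replicate_if_needed() optionally attaches a second P2P instance (a "replica") to the
# daemon that already serves `p2p`, so each test can run against the primary or a replica.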
async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
    return await P2P.replicate(p2p._daemon_listen_port, p2p.external_port) if replicate else p2p

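# bootstrap_addr() builds a libp2p multiaddr for a peer listening on localhost,
# e.g. /ip4/127.0.0.1/tcp/34567/p2p/QmPeer... (port and peer ID here are illustrative).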
def bootstrap_addr(external_port: int, id_: str) -> Multiaddr:
    return Multiaddr(f'/ip4/127.0.0.1/tcp/{external_port}/p2p/{id_}')

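# bootstrap_from() gathers the listen multiaddrs of already-running daemons so that
# freshly created peers can be pointed at them via `bootstrap_peers`.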
async def bootstrap_from(daemons: List[P2P]) -> List[Multiaddr]:
    maddrs = []
    for d in daemons:
        maddrs += await d.identify_maddrs()
    return maddrs

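# Daemon lifecycle tests: P2P.create() spawns a daemon child process (p2p._child);
# shutdown() must reap it, and shutting down a replica must leave the primary's daemon running.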
@pytest.mark.asyncio
async def test_daemon_killed_on_del():
    p2p_daemon = await P2P.create()

    child_pid = p2p_daemon._child.pid
    assert is_process_running(child_pid)

    await p2p_daemon.shutdown()
    assert not is_process_running(child_pid)

@pytest.mark.asyncio
async def test_server_client_connection():
    server = await P2P.create()
    peers = await server._client.list_peers()
    assert len(peers) == 0

    nodes = await bootstrap_from([server])
    client = await P2P.create(bootstrap_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    peers = await client._client.list_peers()
    assert len(peers) == 1
    peers = await server._client.list_peers()
    assert len(peers) == 1

@pytest.mark.asyncio
async def test_daemon_replica_does_not_affect_primary():
    p2p_daemon = await P2P.create()
    p2p_replica = await P2P.replicate(p2p_daemon._daemon_listen_port, p2p_daemon.external_port)

    child_pid = p2p_daemon._child.pid
    assert is_process_running(child_pid)

    await p2p_replica.shutdown()
    assert is_process_running(child_pid)

    await p2p_daemon.shutdown()
    assert not is_process_running(child_pid)

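# Stream handlers receive raw bytes and must return raw bytes. Plain Python values are
# framed with MSGPackSerializer; tensors travel as serialized runtime_pb2.Tensor messages.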
def handle_square(x):
    x = MSGPackSerializer.loads(x)
    return MSGPackSerializer.dumps(x ** 2)

def handle_add(args):
    args = MSGPackSerializer.loads(args)
    result = args[0]
    for i in range(1, len(args)):
        result = result + args[i]
    return MSGPackSerializer.dumps(result)

def handle_square_torch(x):
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(x)
    tensor = deserialize_torch_tensor(tensor)
    result = tensor ** 2
    return serialize_torch_tensor(result).SerializeToString()

def handle_add_torch(args):
    args = MSGPackSerializer.loads(args)
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(args[0])
    result = deserialize_torch_tensor(tensor)

    for i in range(1, len(args)):
        tensor = runtime_pb2.Tensor()
        tensor.ParseFromString(args[i])
        result = result + deserialize_torch_tensor(tensor)

    return serialize_torch_tensor(result).SerializeToString()

def handle_add_torch_with_exc(args):
    try:
        return handle_add_torch(args)
    except Exception:
        return b'something went wrong :('

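# Unary-handler tests. ping_handler sleeps for 2 s, so when the client opens a stream,
# sends the request and immediately closes its writer, the handler is cancelled and the
# asyncio.CancelledError branch records it; otherwise a full request/response is verified.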
@pytest.mark.parametrize(
    'should_cancel,replicate', [
        (True, False),
        (True, True),
        (False, False),
        (False, True),
    ]
)
@pytest.mark.asyncio
async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
    handler_cancelled = False
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)

    async def ping_handler(request, context):
        try:
            await asyncio.sleep(2)
        except asyncio.CancelledError:
            nonlocal handler_cancelled
            handler_cancelled = True
        return dht_pb2.PingResponse(
            peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes(), rpc_port=server.external_port),
            sender_endpoint=context.handle_name, available=True)

    server_pid = server_primary._child.pid
    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
                                   dht_pb2.PingResponse)
    assert is_process_running(server_pid)

    nodes = await bootstrap_from([server])
    client_primary = await P2P.create(bootstrap_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)
    client_pid = client_primary._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    ping_request = dht_pb2.PingRequest(
        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes(), rpc_port=client.external_port),
        validate=True)
    expected_response = dht_pb2.PingResponse(
        peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes(), rpc_port=server.external_port),
        sender_endpoint=handle_name, available=True)

    if should_cancel:
        stream_info, reader, writer = await client._client.stream_open(server.id, (handle_name,))
        await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
        writer.close()
        await asyncio.sleep(1)
        assert handler_cancelled
    else:
        actual_response = await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
        assert actual_response == expected_response
        assert not handler_cancelled

    await server.stop_listening()
    await server_primary.shutdown()
    assert not is_process_running(server_pid)

    await client_primary.shutdown()
    assert not is_process_running(client_pid)

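# An exception raised inside a unary handler should surface on the caller as a
# P2PHandlerError that carries the original error message.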
@pytest.mark.asyncio
async def test_call_unary_handler_error(handle_name="handle"):
    async def error_handler(request, context):
        raise ValueError('boom')

    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
    assert is_process_running(server_pid)

    nodes = await bootstrap_from([server])
    client = await P2P.create(bootstrap_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    ping_request = dht_pb2.PingRequest(
        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes(), rpc_port=client.external_port),
        validate=True)

    with pytest.raises(P2PHandlerError) as excinfo:
        await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
    assert 'boom' in str(excinfo.value)

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()

@pytest.mark.parametrize(
    "test_input,expected,handle",
    [
        pytest.param(10, 100, handle_square, id="square_integer"),
        pytest.param((1, 2), 3, handle_add, id="add_integers"),
        pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda")
    ]
)
@pytest.mark.asyncio
async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_stream_handler(handler_name, handle)
    assert is_process_running(server_pid)

    nodes = await bootstrap_from([server])
    client = await P2P.create(bootstrap_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)

    await client.wait_for_at_least_n_peers(1)

    test_input_msgp = MSGPackSerializer.dumps(test_input)
    result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
    result = MSGPackSerializer.loads(result_msgp)
    assert result == expected

    await server.stop_listening()
    await server.shutdown()
    assert not is_process_running(server_pid)

    await client.shutdown()
    assert not is_process_running(client_pid)

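# Cross-process test scaffolding: run_server() hosts handle_square in a child process,
# sends its peer ID and port back through the Pipe, and keeps serving until the shared
# `response_received` flag tells it the client has received a response.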
async def run_server(handler_name, server_side, client_side, response_received):
    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_stream_handler(handler_name, handle_square)
    assert is_process_running(server_pid)

    server_side.send(server.id)
    server_side.send(server.external_port)
    while response_received.value == 0:
        await asyncio.sleep(0.5)

    await server.stop_listening()
    await server.shutdown()
    assert not is_process_running(server_pid)

def server_target(handler_name, server_side, client_side, response_received):
    asyncio.run(run_server(handler_name, server_side, client_side, response_received))

@pytest.mark.asyncio
async def test_call_peer_different_processes():
    handler_name = "square"
    test_input = 2

    server_side, client_side = mp.Pipe()
    response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
    response_received.value = 0

    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
    proc.start()

    peer_id = client_side.recv()
    peer_port = client_side.recv()

    nodes = [bootstrap_addr(peer_port, peer_id)]
    client = await P2P.create(bootstrap_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)

    await client.wait_for_at_least_n_peers(1)

    test_input_msgp = MSGPackSerializer.dumps(test_input)
    result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
    result = MSGPackSerializer.loads(result_msgp)
    assert np.allclose(result, test_input ** 2)

    response_received.value = 1

    await client.shutdown()
    assert not is_process_running(client_pid)

    proc.join()

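# Tensor round-trip tests: inputs are serialized with serialize_torch_tensor(), results are
# parsed back from runtime_pb2.Tensor and compared with torch.allclose().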
@pytest.mark.parametrize(
    "test_input,expected",
    [
        pytest.param(torch.tensor([2]), torch.tensor(4)),
        pytest.param(
            torch.tensor([[1.0, 2.0], [0.5, 0.1]]),
            torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
    ]
)
@pytest.mark.asyncio
async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
    handle = handle_square_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = await bootstrap_from([server])
    client = await P2P.create(bootstrap_peers=nodes)

    await client.wait_for_at_least_n_peers(1)

    inp = serialize_torch_tensor(test_input).SerializeToString()
    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()

@pytest.mark.parametrize(
    "test_input,expected",
    [
        pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
        pytest.param(
            [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
            torch.tensor([[1.2, 1.4], [1.6, 1.8]])),
    ]
)
@pytest.mark.asyncio
async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
    handle = handle_add_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = await bootstrap_from([server])
    client = await P2P.create(bootstrap_peers=nodes)

    await client.wait_for_at_least_n_peers(1)

    inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()

@pytest.mark.parametrize(
    "replicate",
    [
        pytest.param(False, id="primary"),
        pytest.param(True, id="replica"),
    ]
)
@pytest.mark.asyncio
async def test_call_peer_error(replicate, handler_name="handle"):
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)
    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)

    nodes = await bootstrap_from([server])
    client_primary = await P2P.create(bootstrap_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)

    await client.wait_for_at_least_n_peers(1)

    inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
    assert result == b'something went wrong :('

    await server.stop_listening()
    await server_primary.shutdown()
    await client_primary.shutdown()

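# Each replica can register its own protocol on the shared daemon. Once the replicas stop
# listening, calls to their protocols fail, since the primary does not handle them.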
@pytest.mark.asyncio
async def test_handlers_on_different_replicas(handler_name="handle"):
    def handler(arg, key):
        return key

    server_primary = await P2P.create()
    server_id = server_primary.id
    await server_primary.add_stream_handler(handler_name, partial(handler, key=b'primary'))

    server_replica1 = await replicate_if_needed(server_primary, True)
    await server_replica1.add_stream_handler(handler_name + '1', partial(handler, key=b'replica1'))

    server_replica2 = await replicate_if_needed(server_primary, True)
    await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2'))

    nodes = await bootstrap_from([server_primary])
    client = await P2P.create(bootstrap_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    result = await client.call_peer_handler(server_id, handler_name, b'1')
    assert result == b"primary"

    result = await client.call_peer_handler(server_id, handler_name + '1', b'2')
    assert result == b"replica1"

    result = await client.call_peer_handler(server_id, handler_name + '2', b'3')
    assert result == b"replica2"

    await server_replica1.stop_listening()
    await server_replica2.stop_listening()

    # The primary does not handle the replicas' protocols once their listeners are stopped
    with pytest.raises(Exception):
        await client.call_peer_handler(server_id, handler_name + '1', b'')
    with pytest.raises(Exception):
        await client.call_peer_handler(server_id, handler_name + '2', b'')

    await server_primary.stop_listening()
    await server_primary.shutdown()
    await client.shutdown()