test_p2p_daemon.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. import asyncio
  2. import multiprocessing as mp
  3. import subprocess
  4. from functools import partial
  5. from typing import List
  6. import numpy as np
  7. import pytest
  8. import torch
  9. from hivemind.p2p import P2P
  10. from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
  11. from hivemind.proto import dht_pb2, runtime_pb2
  12. from hivemind.utils import MSGPackSerializer
  13. from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
  14. def is_process_running(pid: int) -> bool:
  15. return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
  16. async def replicate_if_needed(p2p: P2P, replicate: bool):
  17. return await P2P.replicate(p2p._daemon_listen_port, p2p._host_port) if replicate else p2p
  18. def bootstrap_addr(host_port, id_):
  19. return f'/ip4/127.0.0.1/tcp/{host_port}/p2p/{id_}'
  20. def bootstrap_from(daemons: List[P2P]) -> List[str]:
  21. return [bootstrap_addr(d._host_port, d.id) for d in daemons]
  22. @pytest.mark.asyncio
  23. async def test_daemon_killed_on_del():
  24. p2p_daemon = await P2P.create()
  25. child_pid = p2p_daemon._child.pid
  26. assert is_process_running(child_pid)
  27. await p2p_daemon.shutdown()
  28. assert not is_process_running(child_pid)
  29. @pytest.mark.asyncio
  30. async def test_server_client_connection():
  31. server = await P2P.create()
  32. peers = await server._client.list_peers()
  33. assert len(peers) == 0
  34. nodes = bootstrap_from([server])
  35. client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
  36. await client.wait_for_at_least_n_peers(1)
  37. peers = await client._client.list_peers()
  38. assert len(peers) == 1
  39. peers = await server._client.list_peers()
  40. assert len(peers) == 1
  41. @pytest.mark.asyncio
  42. async def test_daemon_replica_does_not_affect_primary():
  43. p2p_daemon = await P2P.create()
  44. p2p_replica = await P2P.replicate(p2p_daemon._daemon_listen_port, p2p_daemon._host_port)
  45. child_pid = p2p_daemon._child.pid
  46. assert is_process_running(child_pid)
  47. await p2p_replica.shutdown()
  48. assert is_process_running(child_pid)
  49. await p2p_daemon.shutdown()
  50. assert not is_process_running(child_pid)
  51. def handle_square(x):
  52. x = MSGPackSerializer.loads(x)
  53. return MSGPackSerializer.dumps(x ** 2)
  54. def handle_add(args):
  55. args = MSGPackSerializer.loads(args)
  56. result = args[0]
  57. for i in range(1, len(args)):
  58. result = result + args[i]
  59. return MSGPackSerializer.dumps(result)
  60. def handle_square_torch(x):
  61. tensor = runtime_pb2.Tensor()
  62. tensor.ParseFromString(x)
  63. tensor = deserialize_torch_tensor(tensor)
  64. result = tensor ** 2
  65. return serialize_torch_tensor(result).SerializeToString()
  66. def handle_add_torch(args):
  67. args = MSGPackSerializer.loads(args)
  68. tensor = runtime_pb2.Tensor()
  69. tensor.ParseFromString(args[0])
  70. result = deserialize_torch_tensor(tensor)
  71. for i in range(1, len(args)):
  72. tensor = runtime_pb2.Tensor()
  73. tensor.ParseFromString(args[i])
  74. result = result + deserialize_torch_tensor(tensor)
  75. return serialize_torch_tensor(result).SerializeToString()
  76. def handle_add_torch_with_exc(args):
  77. try:
  78. return handle_add_torch(args)
  79. except Exception:
  80. return b'something went wrong :('
@pytest.mark.parametrize(
    'should_cancel,replicate', [
        (True, False),
        (True, True),
        (False, False),
        (False, True),
    ]
)
@pytest.mark.asyncio
async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
    """Exercise a unary protobuf handler end to end, optionally cancelling the call.

    A server registers ``ping_handler`` under ``handle_name``. A client opens a raw stream
    to it and sends a ``PingRequest``.

    With ``should_cancel``, the client closes its writer while the handler is still
    sleeping, and the test expects the handler to have seen ``asyncio.CancelledError``.
    Otherwise the test expects the full ``PingResponse`` and no cancellation.

    With ``replicate``, both sides operate through ``replicate_if_needed`` replicas of
    their primary instances.
    """
    # Set to True by ping_handler when its sleep is interrupted by cancellation.
    handler_cancelled = False

    async def ping_handler(request, context):
        # Sleep longer (2 s) than the cancel branch waits (1 s) after closing the
        # stream, so the cancellation can arrive while the handler is still sleeping.
        try:
            await asyncio.sleep(2)
        except asyncio.CancelledError:
            nonlocal handler_cancelled
            handler_cancelled = True
        # Reply with the handler context's identity; compared against expected_response.
        return dht_pb2.PingResponse(
            peer=dht_pb2.NodeInfo(
                node_id=context.id.encode(), rpc_port=context.port),
            sender_endpoint=context.handle_name, available=True)

    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)
    # Daemon pids are always read from the primary instances.
    server_pid = server_primary._child.pid
    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
                                   dht_pb2.PingResponse)
    assert is_process_running(server_pid)

    nodes = bootstrap_from([server])
    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)
    client_pid = client_primary._child.pid
    assert is_process_running(client_pid)

    ping_request = dht_pb2.PingRequest(
        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
        validate=True)
    expected_response = dht_pb2.PingResponse(
        peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
        sender_endpoint=handle_name, available=True)

    await client.wait_for_at_least_n_peers(1)
    libp2p_server_id = PeerID.from_base58(server.id)
    # Open a raw stream (instead of a higher-level call) so the test controls when the
    # writer is closed.
    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
    await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)

    if should_cancel:
        # Closing the writer mid-request is expected to cancel the server-side handler.
        writer.close()
        await asyncio.sleep(1)
        assert handler_cancelled
    else:
        result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
        assert err is None
        assert result == expected_response
        assert not handler_cancelled

    # Cleanup: shut down the primaries and check that their daemons exited.
    # NOTE(review): the replicas, and the writer in the non-cancel branch, are never
    # closed explicitly; confirm this is intended.
    await server.stop_listening()
    await server_primary.shutdown()
    assert not is_process_running(server_pid)
    await client_primary.shutdown()
    assert not is_process_running(client_pid)
  137. @pytest.mark.asyncio
  138. async def test_call_unary_handler_error(handle_name="handle"):
  139. async def error_handler(request, context):
  140. raise ValueError('boom')
  141. server = await P2P.create()
  142. server_pid = server._child.pid
  143. await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
  144. assert is_process_running(server_pid)
  145. nodes = bootstrap_from([server])
  146. client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
  147. client_pid = client._child.pid
  148. assert is_process_running(client_pid)
  149. await client.wait_for_at_least_n_peers(1)
  150. ping_request = dht_pb2.PingRequest(
  151. peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
  152. validate=True)
  153. libp2p_server_id = PeerID.from_base58(server.id)
  154. stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
  155. await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
  156. result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
  157. assert result is None
  158. assert err.message == 'boom'
  159. await server.stop_listening()
  160. await server.shutdown()
  161. await client.shutdown()
  162. @pytest.mark.parametrize(
  163. "test_input,expected,handle",
  164. [
  165. pytest.param(10, 100, handle_square, id="square_integer"),
  166. pytest.param((1, 2), 3, handle_add, id="add_integers"),
  167. pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
  168. pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda")
  169. ]
  170. )
  171. @pytest.mark.asyncio
  172. async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
  173. server = await P2P.create()
  174. server_pid = server._child.pid
  175. await server.add_stream_handler(handler_name, handle)
  176. assert is_process_running(server_pid)
  177. nodes = bootstrap_from([server])
  178. client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
  179. client_pid = client._child.pid
  180. assert is_process_running(client_pid)
  181. await client.wait_for_at_least_n_peers(1)
  182. test_input_msgp = MSGPackSerializer.dumps(test_input)
  183. result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
  184. result = MSGPackSerializer.loads(result_msgp)
  185. assert result == expected
  186. await server.stop_listening()
  187. await server.shutdown()
  188. assert not is_process_running(server_pid)
  189. await client.shutdown()
  190. assert not is_process_running(client_pid)
  191. async def run_server(handler_name, server_side, client_side, response_received):
  192. server = await P2P.create()
  193. server_pid = server._child.pid
  194. await server.add_stream_handler(handler_name, handle_square)
  195. assert is_process_running(server_pid)
  196. server_side.send(server.id)
  197. server_side.send(server._host_port)
  198. while response_received.value == 0:
  199. await asyncio.sleep(0.5)
  200. await server.stop_listening()
  201. await server.shutdown()
  202. assert not is_process_running(server_pid)
  203. def server_target(handler_name, server_side, client_side, response_received):
  204. asyncio.run(run_server(handler_name, server_side, client_side, response_received))
  205. @pytest.mark.asyncio
  206. async def test_call_peer_different_processes():
  207. handler_name = "square"
  208. test_input = 2
  209. server_side, client_side = mp.Pipe()
  210. response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
  211. response_received.value = 0
  212. proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
  213. proc.start()
  214. peer_id = client_side.recv()
  215. peer_port = client_side.recv()
  216. nodes = [bootstrap_addr(peer_port, peer_id)]
  217. client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
  218. client_pid = client._child.pid
  219. assert is_process_running(client_pid)
  220. await client.wait_for_at_least_n_peers(1)
  221. test_input_msgp = MSGPackSerializer.dumps(2)
  222. result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
  223. result = MSGPackSerializer.loads(result_msgp)
  224. assert np.allclose(result, test_input ** 2)
  225. response_received.value = 1
  226. await client.shutdown()
  227. assert not is_process_running(client_pid)
  228. proc.join()
  229. @pytest.mark.parametrize(
  230. "test_input,expected",
  231. [
  232. pytest.param(torch.tensor([2]), torch.tensor(4)),
  233. pytest.param(
  234. torch.tensor([[1.0, 2.0], [0.5, 0.1]]),
  235. torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
  236. ]
  237. )
  238. @pytest.mark.asyncio
  239. async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
  240. handle = handle_square_torch
  241. server = await P2P.create()
  242. await server.add_stream_handler(handler_name, handle)
  243. nodes = bootstrap_from([server])
  244. client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
  245. await client.wait_for_at_least_n_peers(1)
  246. inp = serialize_torch_tensor(test_input).SerializeToString()
  247. result_pb = await client.call_peer_handler(server.id, handler_name, inp)
  248. result = runtime_pb2.Tensor()
  249. result.ParseFromString(result_pb)
  250. result = deserialize_torch_tensor(result)
  251. assert torch.allclose(result, expected)
  252. await server.stop_listening()
  253. await server.shutdown()
  254. await client.shutdown()
  255. @pytest.mark.parametrize(
  256. "test_input,expected",
  257. [
  258. pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
  259. pytest.param(
  260. [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
  261. torch.tensor([[1.2, 1.4], [1.6, 1.8]])),
  262. ]
  263. )
  264. @pytest.mark.asyncio
  265. async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
  266. handle = handle_add_torch
  267. server = await P2P.create()
  268. await server.add_stream_handler(handler_name, handle)
  269. nodes = bootstrap_from([server])
  270. client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
  271. await client.wait_for_at_least_n_peers(1)
  272. inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
  273. inp_msgp = MSGPackSerializer.dumps(inp)
  274. result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
  275. result = runtime_pb2.Tensor()
  276. result.ParseFromString(result_pb)
  277. result = deserialize_torch_tensor(result)
  278. assert torch.allclose(result, expected)
  279. await server.stop_listening()
  280. await server.shutdown()
  281. await client.shutdown()
  282. @pytest.mark.parametrize(
  283. "replicate",
  284. [
  285. pytest.param(False, id="primary"),
  286. pytest.param(True, id="replica"),
  287. ]
  288. )
  289. @pytest.mark.asyncio
  290. async def test_call_peer_error(replicate, handler_name="handle"):
  291. server_primary = await P2P.create()
  292. server = await replicate_if_needed(server_primary, replicate)
  293. await server.add_stream_handler(handler_name, handle_add_torch_with_exc)
  294. nodes = bootstrap_from([server])
  295. client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
  296. client = await replicate_if_needed(client_primary, replicate)
  297. await client.wait_for_at_least_n_peers(1)
  298. inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
  299. inp_msgp = MSGPackSerializer.dumps(inp)
  300. result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
  301. assert result == b'something went wrong :('
  302. await server.stop_listening()
  303. await server_primary.shutdown()
  304. await client_primary.shutdown()
  305. @pytest.mark.asyncio
  306. async def test_handlers_on_different_replicas(handler_name="handle"):
  307. def handler(arg, key):
  308. return key
  309. server_primary = await P2P.create(bootstrap=False)
  310. server_id = server_primary.id
  311. await server_primary.add_stream_handler(handler_name, partial(handler, key=b'primary'))
  312. server_replica1 = await replicate_if_needed(server_primary, True)
  313. await server_replica1.add_stream_handler(handler_name + '1', partial(handler, key=b'replica1'))
  314. server_replica2 = await replicate_if_needed(server_primary, True)
  315. await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2'))
  316. nodes = bootstrap_from([server_primary])
  317. client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
  318. await client.wait_for_at_least_n_peers(1)
  319. result = await client.call_peer_handler(server_id, handler_name, b'1')
  320. assert result == b"primary"
  321. result = await client.call_peer_handler(server_id, handler_name + '1', b'2')
  322. assert result == b"replica1"
  323. result = await client.call_peer_handler(server_id, handler_name + '2', b'3')
  324. assert result == b"replica2"
  325. await server_replica1.stop_listening()
  326. await server_replica2.stop_listening()
  327. # Primary does not handle replicas protocols
  328. with pytest.raises(asyncio.IncompleteReadError):
  329. await client.call_peer_handler(server_id, handler_name + '1', b'')
  330. with pytest.raises(asyncio.IncompleteReadError):
  331. await client.call_peer_handler(server_id, handler_name + '2', b'')
  332. await server_primary.stop_listening()
  333. await server_primary.shutdown()
  334. await client.shutdown()