test_p2p_daemon.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. import asyncio
  2. import multiprocessing as mp
  3. import socket
  4. import subprocess
  5. from functools import partial
  6. from typing import List
  7. import numpy as np
  8. import pytest
  9. import torch
  10. from multiaddr import Multiaddr
  11. from hivemind.p2p import P2P, P2PHandlerError
  12. from hivemind.proto import dht_pb2, runtime_pb2
  13. from hivemind.utils import MSGPackSerializer
  14. from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
  15. from hivemind.utils.networking import find_open_port
  16. def is_process_running(pid: int) -> bool:
  17. return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
  18. async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
  19. return await P2P.replicate(p2p.daemon_listen_maddr) if replicate else p2p
  20. async def bootstrap_from(daemons: List[P2P]) -> List[Multiaddr]:
  21. maddrs = []
  22. for d in daemons:
  23. maddrs += await d.identify_maddrs()
  24. return maddrs
  25. @pytest.mark.asyncio
  26. async def test_daemon_killed_on_del():
  27. p2p_daemon = await P2P.create()
  28. child_pid = p2p_daemon._child.pid
  29. assert is_process_running(child_pid)
  30. await p2p_daemon.shutdown()
  31. assert not is_process_running(child_pid)
  32. @pytest.mark.asyncio
  33. async def test_error_for_wrong_daemon_arguments():
  34. with pytest.raises(RuntimeError):
  35. await P2P.create(unknown_argument=True)
  36. @pytest.mark.asyncio
  37. async def test_server_client_connection():
  38. server = await P2P.create()
  39. peers = await server.list_peers()
  40. assert len(peers) == 0
  41. nodes = await bootstrap_from([server])
  42. client = await P2P.create(bootstrap_peers=nodes)
  43. await client.wait_for_at_least_n_peers(1)
  44. peers = await client.list_peers()
  45. assert len(peers) == 1
  46. peers = await server.list_peers()
  47. assert len(peers) == 1
  48. @pytest.mark.asyncio
  49. async def test_quic_transport():
  50. server_port = find_open_port((socket.AF_INET, socket.SOCK_DGRAM))
  51. server = await P2P.create(quic=True, host_maddrs=[Multiaddr(f'/ip4/127.0.0.1/udp/{server_port}/quic')])
  52. peers = await server.list_peers()
  53. assert len(peers) == 0
  54. nodes = await bootstrap_from([server])
  55. client_port = find_open_port((socket.AF_INET, socket.SOCK_DGRAM))
  56. client = await P2P.create(quic=True, host_maddrs=[Multiaddr(f'/ip4/127.0.0.1/udp/{client_port}/quic')],
  57. bootstrap_peers=nodes)
  58. await client.wait_for_at_least_n_peers(1)
  59. peers = await client.list_peers()
  60. assert len(peers) == 1
  61. peers = await server.list_peers()
  62. assert len(peers) == 1
  63. @pytest.mark.asyncio
  64. async def test_daemon_replica_does_not_affect_primary():
  65. p2p_daemon = await P2P.create()
  66. p2p_replica = await P2P.replicate(p2p_daemon.daemon_listen_maddr)
  67. child_pid = p2p_daemon._child.pid
  68. assert is_process_running(child_pid)
  69. await p2p_replica.shutdown()
  70. assert is_process_running(child_pid)
  71. await p2p_daemon.shutdown()
  72. assert not is_process_running(child_pid)
  73. def handle_square(x):
  74. x = MSGPackSerializer.loads(x)
  75. return MSGPackSerializer.dumps(x ** 2)
  76. def handle_add(args):
  77. args = MSGPackSerializer.loads(args)
  78. result = args[0]
  79. for i in range(1, len(args)):
  80. result = result + args[i]
  81. return MSGPackSerializer.dumps(result)
  82. def handle_square_torch(x):
  83. tensor = runtime_pb2.Tensor()
  84. tensor.ParseFromString(x)
  85. tensor = deserialize_torch_tensor(tensor)
  86. result = tensor ** 2
  87. return serialize_torch_tensor(result).SerializeToString()
  88. def handle_add_torch(args):
  89. args = MSGPackSerializer.loads(args)
  90. tensor = runtime_pb2.Tensor()
  91. tensor.ParseFromString(args[0])
  92. result = deserialize_torch_tensor(tensor)
  93. for i in range(1, len(args)):
  94. tensor = runtime_pb2.Tensor()
  95. tensor.ParseFromString(args[i])
  96. result = result + deserialize_torch_tensor(tensor)
  97. return serialize_torch_tensor(result).SerializeToString()
  98. def handle_add_torch_with_exc(args):
  99. try:
  100. return handle_add_torch(args)
  101. except Exception:
  102. return b'something went wrong :('
  103. @pytest.mark.parametrize(
  104. 'should_cancel,replicate', [
  105. (True, False),
  106. (True, True),
  107. (False, False),
  108. (False, True),
  109. ]
  110. )
  111. @pytest.mark.asyncio
  112. async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
  113. handler_cancelled = False
  114. server_primary = await P2P.create()
  115. server = await replicate_if_needed(server_primary, replicate)
  116. async def ping_handler(request, context):
  117. try:
  118. await asyncio.sleep(2)
  119. except asyncio.CancelledError:
  120. nonlocal handler_cancelled
  121. handler_cancelled = True
  122. return dht_pb2.PingResponse(
  123. peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
  124. sender_endpoint=context.handle_name, available=True)
  125. server_pid = server_primary._child.pid
  126. await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
  127. dht_pb2.PingResponse)
  128. assert is_process_running(server_pid)
  129. nodes = await bootstrap_from([server])
  130. client_primary = await P2P.create(bootstrap_peers=nodes)
  131. client = await replicate_if_needed(client_primary, replicate)
  132. client_pid = client_primary._child.pid
  133. assert is_process_running(client_pid)
  134. await client.wait_for_at_least_n_peers(1)
  135. ping_request = dht_pb2.PingRequest(
  136. peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()),
  137. validate=True)
  138. expected_response = dht_pb2.PingResponse(
  139. peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
  140. sender_endpoint=handle_name, available=True)
  141. if should_cancel:
  142. stream_info, reader, writer = await client._client.stream_open(server.id, (handle_name,))
  143. await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
  144. writer.close()
  145. await asyncio.sleep(1)
  146. assert handler_cancelled
  147. else:
  148. actual_response = await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
  149. assert actual_response == expected_response
  150. assert not handler_cancelled
  151. await server.shutdown()
  152. await server_primary.shutdown()
  153. assert not is_process_running(server_pid)
  154. await client_primary.shutdown()
  155. assert not is_process_running(client_pid)
  156. @pytest.mark.asyncio
  157. async def test_call_unary_handler_error(handle_name="handle"):
  158. async def error_handler(request, context):
  159. raise ValueError('boom')
  160. server = await P2P.create()
  161. server_pid = server._child.pid
  162. await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
  163. assert is_process_running(server_pid)
  164. nodes = await bootstrap_from([server])
  165. client = await P2P.create(bootstrap_peers=nodes)
  166. client_pid = client._child.pid
  167. assert is_process_running(client_pid)
  168. await client.wait_for_at_least_n_peers(1)
  169. ping_request = dht_pb2.PingRequest(
  170. peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()),
  171. validate=True)
  172. with pytest.raises(P2PHandlerError) as excinfo:
  173. await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
  174. assert 'boom' in str(excinfo.value)
  175. await server.shutdown()
  176. await client.shutdown()
  177. @pytest.mark.parametrize(
  178. "test_input,expected,handle",
  179. [
  180. pytest.param(10, 100, handle_square, id="square_integer"),
  181. pytest.param((1, 2), 3, handle_add, id="add_integers"),
  182. pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
  183. pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda")
  184. ]
  185. )
  186. @pytest.mark.asyncio
  187. async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
  188. server = await P2P.create()
  189. server_pid = server._child.pid
  190. await server.add_stream_handler(handler_name, handle)
  191. assert is_process_running(server_pid)
  192. nodes = await bootstrap_from([server])
  193. client = await P2P.create(bootstrap_peers=nodes)
  194. client_pid = client._child.pid
  195. assert is_process_running(client_pid)
  196. await client.wait_for_at_least_n_peers(1)
  197. test_input_msgp = MSGPackSerializer.dumps(test_input)
  198. result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
  199. result = MSGPackSerializer.loads(result_msgp)
  200. assert result == expected
  201. await server.shutdown()
  202. assert not is_process_running(server_pid)
  203. await client.shutdown()
  204. assert not is_process_running(client_pid)
  205. async def run_server(handler_name, server_side, client_side, response_received):
  206. server = await P2P.create()
  207. server_pid = server._child.pid
  208. await server.add_stream_handler(handler_name, handle_square)
  209. assert is_process_running(server_pid)
  210. server_side.send(server.id)
  211. server_side.send(await server.identify_maddrs())
  212. while response_received.value == 0:
  213. await asyncio.sleep(0.5)
  214. await server.shutdown()
  215. assert not is_process_running(server_pid)
  216. def server_target(handler_name, server_side, client_side, response_received):
  217. asyncio.run(run_server(handler_name, server_side, client_side, response_received))
  218. @pytest.mark.asyncio
  219. async def test_call_peer_different_processes():
  220. handler_name = "square"
  221. test_input = 2
  222. server_side, client_side = mp.Pipe()
  223. response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
  224. response_received.value = 0
  225. proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
  226. proc.start()
  227. peer_id = client_side.recv()
  228. peer_maddrs = client_side.recv()
  229. client = await P2P.create(bootstrap_peers=peer_maddrs)
  230. client_pid = client._child.pid
  231. assert is_process_running(client_pid)
  232. await client.wait_for_at_least_n_peers(1)
  233. test_input_msgp = MSGPackSerializer.dumps(2)
  234. result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
  235. result = MSGPackSerializer.loads(result_msgp)
  236. assert np.allclose(result, test_input ** 2)
  237. response_received.value = 1
  238. await client.shutdown()
  239. assert not is_process_running(client_pid)
  240. proc.join()
  241. @pytest.mark.parametrize(
  242. "test_input,expected",
  243. [
  244. pytest.param(torch.tensor([2]), torch.tensor(4)),
  245. pytest.param(
  246. torch.tensor([[1.0, 2.0], [0.5, 0.1]]),
  247. torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
  248. ]
  249. )
  250. @pytest.mark.asyncio
  251. async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
  252. handle = handle_square_torch
  253. server = await P2P.create()
  254. await server.add_stream_handler(handler_name, handle)
  255. nodes = await bootstrap_from([server])
  256. client = await P2P.create(bootstrap_peers=nodes)
  257. await client.wait_for_at_least_n_peers(1)
  258. inp = serialize_torch_tensor(test_input).SerializeToString()
  259. result_pb = await client.call_peer_handler(server.id, handler_name, inp)
  260. result = runtime_pb2.Tensor()
  261. result.ParseFromString(result_pb)
  262. result = deserialize_torch_tensor(result)
  263. assert torch.allclose(result, expected)
  264. await server.shutdown()
  265. await client.shutdown()
  266. @pytest.mark.parametrize(
  267. "test_input,expected",
  268. [
  269. pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
  270. pytest.param(
  271. [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
  272. torch.tensor([[1.2, 1.4], [1.6, 1.8]])),
  273. ]
  274. )
  275. @pytest.mark.asyncio
  276. async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
  277. handle = handle_add_torch
  278. server = await P2P.create()
  279. await server.add_stream_handler(handler_name, handle)
  280. nodes = await bootstrap_from([server])
  281. client = await P2P.create(bootstrap_peers=nodes)
  282. await client.wait_for_at_least_n_peers(1)
  283. inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
  284. inp_msgp = MSGPackSerializer.dumps(inp)
  285. result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
  286. result = runtime_pb2.Tensor()
  287. result.ParseFromString(result_pb)
  288. result = deserialize_torch_tensor(result)
  289. assert torch.allclose(result, expected)
  290. await server.shutdown()
  291. await client.shutdown()
  292. @pytest.mark.parametrize(
  293. "replicate",
  294. [
  295. pytest.param(False, id="primary"),
  296. pytest.param(True, id="replica"),
  297. ]
  298. )
  299. @pytest.mark.asyncio
  300. async def test_call_peer_error(replicate, handler_name="handle"):
  301. server_primary = await P2P.create()
  302. server = await replicate_if_needed(server_primary, replicate)
  303. await server.add_stream_handler(handler_name, handle_add_torch_with_exc)
  304. nodes = await bootstrap_from([server])
  305. client_primary = await P2P.create(bootstrap_peers=nodes)
  306. client = await replicate_if_needed(client_primary, replicate)
  307. await client.wait_for_at_least_n_peers(1)
  308. inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
  309. inp_msgp = MSGPackSerializer.dumps(inp)
  310. result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
  311. assert result == b'something went wrong :('
  312. await server_primary.shutdown()
  313. await server.shutdown()
  314. await client_primary.shutdown()
  315. await client.shutdown()
  316. @pytest.mark.asyncio
  317. async def test_handlers_on_different_replicas(handler_name="handle"):
  318. def handler(arg, key):
  319. return key
  320. server_primary = await P2P.create()
  321. server_id = server_primary.id
  322. await server_primary.add_stream_handler(handler_name, partial(handler, key=b'primary'))
  323. server_replica1 = await replicate_if_needed(server_primary, True)
  324. await server_replica1.add_stream_handler(handler_name + '1', partial(handler, key=b'replica1'))
  325. server_replica2 = await replicate_if_needed(server_primary, True)
  326. await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2'))
  327. nodes = await bootstrap_from([server_primary])
  328. client = await P2P.create(bootstrap_peers=nodes)
  329. await client.wait_for_at_least_n_peers(1)
  330. result = await client.call_peer_handler(server_id, handler_name, b'1')
  331. assert result == b"primary"
  332. result = await client.call_peer_handler(server_id, handler_name + '1', b'2')
  333. assert result == b"replica1"
  334. result = await client.call_peer_handler(server_id, handler_name + '2', b'3')
  335. assert result == b"replica2"
  336. await server_replica1.shutdown()
  337. await server_replica2.shutdown()
  338. # Primary does not handle replicas protocols
  339. with pytest.raises(Exception):
  340. await client.call_peer_handler(server_id, handler_name + '1', b'')
  341. with pytest.raises(Exception):
  342. await client.call_peer_handler(server_id, handler_name + '2', b'')
  343. await server_primary.shutdown()
  344. await client.shutdown()