test_p2p_daemon.py

import asyncio
import multiprocessing as mp
import subprocess
from functools import partial
from typing import List

import numpy as np
import pytest
import torch
from multiaddr import Multiaddr

from hivemind.p2p import P2P, P2PHandlerError
from hivemind.proto import dht_pb2, runtime_pb2
from hivemind.utils import MSGPackSerializer
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
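

# Shared helpers: check whether a daemon child process is still alive, optionally
# wrap a P2P instance in a replica, and collect visible multiaddrs for bootstrapping.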
def is_process_running(pid: int) -> bool:
    return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0


async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
    return await P2P.replicate(p2p.daemon_listen_maddr) if replicate else p2p


async def bootstrap_from(daemons: List[P2P]) -> List[Multiaddr]:
    maddrs = []
    for d in daemons:
        maddrs += await d.get_visible_maddrs()
    return maddrs
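

# The daemon child process must be terminated once the owning P2P instance shuts down.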
@pytest.mark.asyncio
async def test_daemon_killed_on_del():
    p2p_daemon = await P2P.create()

    child_pid = p2p_daemon._child.pid
    assert is_process_running(child_pid)

    await p2p_daemon.shutdown()
    assert not is_process_running(child_pid)
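

# The daemons should start and interconnect over TCP, QUIC, or both transports.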
@pytest.mark.parametrize(
    'host_maddrs', [
        [Multiaddr('/ip4/127.0.0.1/tcp/0')],
        [Multiaddr('/ip4/127.0.0.1/udp/0/quic')],
        [Multiaddr('/ip4/127.0.0.1/tcp/0'), Multiaddr('/ip4/127.0.0.1/udp/0/quic')],
    ]
)
@pytest.mark.asyncio
async def test_transports(host_maddrs: List[Multiaddr]):
    server = await P2P.create(quic=True, host_maddrs=host_maddrs)
    peers = await server.list_peers()
    assert len(peers) == 0

    nodes = await bootstrap_from([server])
    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    peers = await client.list_peers()
    assert len(peers) == 1
    peers = await server.list_peers()
    assert len(peers) == 1
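

# Shutting down a replica must leave the primary daemon process untouched.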
@pytest.mark.asyncio
async def test_daemon_replica_does_not_affect_primary():
    p2p_daemon = await P2P.create()
    p2p_replica = await P2P.replicate(p2p_daemon.daemon_listen_maddr)

    child_pid = p2p_daemon._child.pid
    assert is_process_running(child_pid)

    await p2p_replica.shutdown()
    assert is_process_running(child_pid)

    await p2p_daemon.shutdown()
    assert not is_process_running(child_pid)
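

# Stream handlers used by the tests below: each receives raw bytes and returns raw bytes,
# using MSGPack for plain Python values and runtime_pb2.Tensor protobufs for torch tensors.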
def handle_square(x):
    x = MSGPackSerializer.loads(x)
    return MSGPackSerializer.dumps(x ** 2)


def handle_add(args):
    args = MSGPackSerializer.loads(args)
    result = args[0]
    for i in range(1, len(args)):
        result = result + args[i]
    return MSGPackSerializer.dumps(result)


def handle_square_torch(x):
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(x)
    tensor = deserialize_torch_tensor(tensor)
    result = tensor ** 2
    return serialize_torch_tensor(result).SerializeToString()


def handle_add_torch(args):
    args = MSGPackSerializer.loads(args)
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(args[0])
    result = deserialize_torch_tensor(tensor)

    for i in range(1, len(args)):
        tensor = runtime_pb2.Tensor()
        tensor.ParseFromString(args[i])
        result = result + deserialize_torch_tensor(tensor)

    return serialize_torch_tensor(result).SerializeToString()


def handle_add_torch_with_exc(args):
    try:
        return handle_add_torch(args)
    except Exception:
        return b'something went wrong :('
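

# Unary handler round-trip, over primaries or replicas: either cancel the call mid-flight
# (closing the stream should cancel the handler coroutine) or wait for the full response.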
@pytest.mark.parametrize(
    'should_cancel,replicate', [
        (True, False),
        (True, True),
        (False, False),
        (False, True),
    ]
)
@pytest.mark.asyncio
async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
    handler_cancelled = False
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)

    async def ping_handler(request, context):
        try:
            await asyncio.sleep(2)
        except asyncio.CancelledError:
            nonlocal handler_cancelled
            handler_cancelled = True
        return dht_pb2.PingResponse(
            peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
            available=True)

    server_pid = server_primary._child.pid
    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
                                   dht_pb2.PingResponse)
    assert is_process_running(server_pid)

    nodes = await bootstrap_from([server])
    client_primary = await P2P.create(initial_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)
    client_pid = client_primary._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    ping_request = dht_pb2.PingRequest(
        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()),
        validate=True)
    expected_response = dht_pb2.PingResponse(
        peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
        available=True)

    if should_cancel:
        stream_info, reader, writer = await client._client.stream_open(server.id, (handle_name,))
        await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
        writer.close()
        await asyncio.sleep(1)
        assert handler_cancelled
    else:
        actual_response = await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
        assert actual_response == expected_response
        assert not handler_cancelled

    await server.shutdown()
    await server_primary.shutdown()
    assert not is_process_running(server_pid)

    await client_primary.shutdown()
    assert not is_process_running(client_pid)
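

# Exceptions raised inside a unary handler should surface on the client as P2PHandlerError.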
@pytest.mark.asyncio
async def test_call_unary_handler_error(handle_name="handle"):
    async def error_handler(request, context):
        raise ValueError('boom')

    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
    assert is_process_running(server_pid)

    nodes = await bootstrap_from([server])
    client = await P2P.create(initial_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    ping_request = dht_pb2.PingRequest(
        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()),
        validate=True)

    with pytest.raises(P2PHandlerError) as excinfo:
        await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
    assert 'boom' in str(excinfo.value)

    await server.shutdown()
    await client.shutdown()
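

# Stream handler round-trip within a single process, parametrized over several payloads.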
@pytest.mark.parametrize(
    "test_input,expected,handle",
    [
        pytest.param(10, 100, handle_square, id="square_integer"),
        pytest.param((1, 2), 3, handle_add, id="add_integers"),
        pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda"),
    ]
)
@pytest.mark.asyncio
async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_stream_handler(handler_name, handle)
    assert is_process_running(server_pid)

    nodes = await bootstrap_from([server])
    client = await P2P.create(initial_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    test_input_msgp = MSGPackSerializer.dumps(test_input)
    result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
    result = MSGPackSerializer.loads(result_msgp)
    assert result == expected

    await server.shutdown()
    assert not is_process_running(server_pid)
    await client.shutdown()
    assert not is_process_running(client_pid)
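

# Server loop for the cross-process test: it runs in a separate process, publishes its
# peer ID and multiaddrs through a pipe, and keeps serving until the client sets the flag.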
async def run_server(handler_name, server_side, client_side, response_received):
    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_stream_handler(handler_name, handle_square)
    assert is_process_running(server_pid)

    server_side.send(server.id)
    server_side.send(await server.get_visible_maddrs())

    while response_received.value == 0:
        await asyncio.sleep(0.5)

    await server.shutdown()
    assert not is_process_running(server_pid)


def server_target(handler_name, server_side, client_side, response_received):
    asyncio.run(run_server(handler_name, server_side, client_side, response_received))
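

# Client and server daemons live in different OS processes and connect via the pipe-published multiaddrs.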
@pytest.mark.asyncio
async def test_call_peer_different_processes():
    handler_name = "square"
    test_input = 2

    server_side, client_side = mp.Pipe()
    response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
    response_received.value = 0

    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
    proc.start()

    peer_id = client_side.recv()
    peer_maddrs = client_side.recv()

    client = await P2P.create(initial_peers=peer_maddrs)
    client_pid = client._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    test_input_msgp = MSGPackSerializer.dumps(test_input)
    result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
    result = MSGPackSerializer.loads(result_msgp)
    assert np.allclose(result, test_input ** 2)
    response_received.value = 1

    await client.shutdown()
    assert not is_process_running(client_pid)

    proc.join()
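

# Tensor squaring through a stream handler, using protobuf tensor (de)serialization.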
@pytest.mark.parametrize(
    "test_input,expected",
    [
        pytest.param(torch.tensor([2]), torch.tensor(4)),
        pytest.param(
            torch.tensor([[1.0, 2.0], [0.5, 0.1]]),
            torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
    ]
)
@pytest.mark.asyncio
async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
    handle = handle_square_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = await bootstrap_from([server])
    client = await P2P.create(initial_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    inp = serialize_torch_tensor(test_input).SerializeToString()
    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.shutdown()
    await client.shutdown()
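

# Element-wise addition of several tensors passed as a MSGPack list of serialized protobufs.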
@pytest.mark.parametrize(
    "test_input,expected",
    [
        pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
        pytest.param(
            [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
            torch.tensor([[1.2, 1.4], [1.6, 1.8]])),
    ]
)
@pytest.mark.asyncio
async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
    handle = handle_add_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = await bootstrap_from([server])
    client = await P2P.create(initial_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.shutdown()
    await client.shutdown()
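

# A handler that catches its own exception should still return a well-formed error payload
# (here, adding tensors with mismatched shapes), whether it runs on a primary or a replica.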
@pytest.mark.parametrize(
    "replicate",
    [
        pytest.param(False, id="primary"),
        pytest.param(True, id="replica"),
    ]
)
@pytest.mark.asyncio
async def test_call_peer_error(replicate, handler_name="handle"):
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)
    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)

    nodes = await bootstrap_from([server])
    client_primary = await P2P.create(initial_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)
    await client.wait_for_at_least_n_peers(1)

    inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
    assert result == b'something went wrong :('

    await server_primary.shutdown()
    await server.shutdown()
    await client_primary.shutdown()
    await client.shutdown()
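

# Each replica registers its own handler name; once the replicas shut down, their handlers
# must no longer be reachable through the primary.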
@pytest.mark.asyncio
async def test_handlers_on_different_replicas(handler_name="handle"):
    def handler(arg, key):
        return key

    server_primary = await P2P.create()
    server_id = server_primary.id
    await server_primary.add_stream_handler(handler_name, partial(handler, key=b'primary'))

    server_replica1 = await replicate_if_needed(server_primary, True)
    await server_replica1.add_stream_handler(handler_name + '1', partial(handler, key=b'replica1'))

    server_replica2 = await replicate_if_needed(server_primary, True)
    await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2'))

    nodes = await bootstrap_from([server_primary])
    client = await P2P.create(initial_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    result = await client.call_peer_handler(server_id, handler_name, b'1')
    assert result == b"primary"

    result = await client.call_peer_handler(server_id, handler_name + '1', b'2')
    assert result == b"replica1"

    result = await client.call_peer_handler(server_id, handler_name + '2', b'3')
    assert result == b"replica2"

    await server_replica1.shutdown()
    await server_replica2.shutdown()

    # The primary does not handle the replicas' protocols after they shut down
    with pytest.raises(Exception):
        await client.call_peer_handler(server_id, handler_name + '1', b'')
    with pytest.raises(Exception):
        await client.call_peer_handler(server_id, handler_name + '2', b'')

    await server_primary.shutdown()
    await client.shutdown()