# test_p2p_daemon.py
import asyncio
import multiprocessing as mp
import subprocess
from functools import partial
from typing import List

import numpy as np
import pytest
import torch
from multiaddr import Multiaddr

from hivemind.p2p import P2P, P2PHandlerError
from hivemind.proto import dht_pb2, runtime_pb2
from hivemind.utils import MSGPackSerializer
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor


def is_process_running(pid: int) -> bool:
    return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0


async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
    return await P2P.replicate(p2p.daemon_listen_maddr) if replicate else p2p


async def bootstrap_from(daemons: List[P2P]) -> List[Multiaddr]:
    maddrs = []
    for d in daemons:
        maddrs += await d.get_visible_maddrs()
    return maddrs


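# The p2pd child process should be alive while the P2P instance exists and die on shutdown().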
@pytest.mark.asyncio
async def test_daemon_killed_on_del():
    p2p_daemon = await P2P.create()

    child_pid = p2p_daemon._child.pid
    assert is_process_running(child_pid)

    await p2p_daemon.shutdown()
    assert not is_process_running(child_pid)


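# Peers should discover each other over each configured transport: TCP, QUIC, or both.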
@pytest.mark.parametrize(
    "host_maddrs",
    [
        [Multiaddr("/ip4/127.0.0.1/tcp/0")],
        [Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
        [Multiaddr("/ip4/127.0.0.1/tcp/0"), Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
    ],
)
@pytest.mark.asyncio
async def test_transports(host_maddrs: List[Multiaddr]):
    server = await P2P.create(quic=True, host_maddrs=host_maddrs)
    peers = await server.list_peers()
    assert len(peers) == 0

    nodes = await bootstrap_from([server])
    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    peers = await client.list_peers()
    assert len(peers) == 1
    peers = await server.list_peers()
    assert len(peers) == 1


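# Shutting down a replica must not kill the primary's daemon process.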
@pytest.mark.asyncio
async def test_daemon_replica_does_not_affect_primary():
    p2p_daemon = await P2P.create()
    p2p_replica = await P2P.replicate(p2p_daemon.daemon_listen_maddr)

    child_pid = p2p_daemon._child.pid
    assert is_process_running(child_pid)

    await p2p_replica.shutdown()
    assert is_process_running(child_pid)

    await p2p_daemon.shutdown()
    assert not is_process_running(child_pid)


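# Stream handlers used by the call_peer tests below: each one deserializes its payload,
# computes the result, and returns it serialized.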
def handle_square(x):
    x = MSGPackSerializer.loads(x)
    return MSGPackSerializer.dumps(x ** 2)


def handle_add(args):
    args = MSGPackSerializer.loads(args)
    result = args[0]
    for i in range(1, len(args)):
        result = result + args[i]
    return MSGPackSerializer.dumps(result)


def handle_square_torch(x):
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(x)
    tensor = deserialize_torch_tensor(tensor)
    result = tensor ** 2
    return serialize_torch_tensor(result).SerializeToString()


def handle_add_torch(args):
    args = MSGPackSerializer.loads(args)
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(args[0])
    result = deserialize_torch_tensor(tensor)

    for i in range(1, len(args)):
        tensor = runtime_pb2.Tensor()
        tensor.ParseFromString(args[i])
        result = result + deserialize_torch_tensor(tensor)

    return serialize_torch_tensor(result).SerializeToString()


def handle_add_torch_with_exc(args):
    try:
        return handle_add_torch(args)
    except Exception:
        return b"something went wrong :("


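# A unary handler should serve protobuf requests from either a primary or a replica;
# closing the stream early must cancel the handler coroutine.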
@pytest.mark.parametrize(
    "should_cancel,replicate",
    [
        (True, False),
        (True, True),
        (False, False),
        (False, True),
    ],
)
@pytest.mark.asyncio
async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
    handler_cancelled = False
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)

    async def ping_handler(request, context):
        try:
            await asyncio.sleep(2)
        except asyncio.CancelledError:
            nonlocal handler_cancelled
            handler_cancelled = True
        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)

    server_pid = server_primary._child.pid
    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
    assert is_process_running(server_pid)

    nodes = await bootstrap_from([server])
    client_primary = await P2P.create(initial_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)
    client_pid = client_primary._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)

    if should_cancel:
        stream_info, reader, writer = await client._client.stream_open(server.id, (handle_name,))
        await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
        writer.close()
        await asyncio.sleep(1)
        assert handler_cancelled
    else:
        actual_response = await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
        assert actual_response == expected_response
        assert not handler_cancelled

    await server.shutdown()
    await server_primary.shutdown()
    assert not is_process_running(server_pid)

    await client_primary.shutdown()
    assert not is_process_running(client_pid)


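# An exception raised inside a unary handler should reach the client as a P2PHandlerError.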
@pytest.mark.asyncio
async def test_call_unary_handler_error(handle_name="handle"):
    async def error_handler(request, context):
        raise ValueError("boom")

    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
    assert is_process_running(server_pid)

    nodes = await bootstrap_from([server])
    client = await P2P.create(initial_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)

    with pytest.raises(P2PHandlerError) as excinfo:
        await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
    assert "boom" in str(excinfo.value)

    await server.shutdown()
    await client.shutdown()


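# call_peer_handler should round-trip msgpack payloads through a registered stream handler within a single process.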
@pytest.mark.parametrize(
    "test_input,expected,handle",
    [
        pytest.param(10, 100, handle_square, id="square_integer"),
        pytest.param((1, 2), 3, handle_add, id="add_integers"),
        pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda"),
    ],
)
@pytest.mark.asyncio
async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_stream_handler(handler_name, handle)
    assert is_process_running(server_pid)

    nodes = await bootstrap_from([server])
    client = await P2P.create(initial_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    test_input_msgp = MSGPackSerializer.dumps(test_input)
    result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
    result = MSGPackSerializer.loads(result_msgp)
    assert result == expected

    await server.shutdown()
    assert not is_process_running(server_pid)
    await client.shutdown()
    assert not is_process_running(client_pid)


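# Helpers for the cross-process test below: run the server in a child process and report
# its peer ID and visible multiaddrs through a pipe.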
async def run_server(handler_name, server_side, client_side, response_received):
    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_stream_handler(handler_name, handle_square)
    assert is_process_running(server_pid)

    server_side.send(server.id)
    server_side.send(await server.get_visible_maddrs())
    while response_received.value == 0:
        await asyncio.sleep(0.5)

    await server.shutdown()
    assert not is_process_running(server_pid)


def server_target(handler_name, server_side, client_side, response_received):
    asyncio.run(run_server(handler_name, server_side, client_side, response_received))


@pytest.mark.asyncio
async def test_call_peer_different_processes():
    handler_name = "square"
    test_input = 2

    server_side, client_side = mp.Pipe()
    response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
    response_received.value = 0

    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
    proc.start()

    peer_id = client_side.recv()
    peer_maddrs = client_side.recv()

    client = await P2P.create(initial_peers=peer_maddrs)
    client_pid = client._child.pid
    assert is_process_running(client_pid)
    await client.wait_for_at_least_n_peers(1)

    test_input_msgp = MSGPackSerializer.dumps(test_input)
    result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
    result = MSGPackSerializer.loads(result_msgp)
    assert np.allclose(result, test_input ** 2)

    response_received.value = 1
    await client.shutdown()
    assert not is_process_running(client_pid)

    proc.join()


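# A tensor serialized with serialize_torch_tensor should come back squared from the remote handler.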
@pytest.mark.parametrize(
    "test_input,expected",
    [
        pytest.param(torch.tensor([2]), torch.tensor(4)),
        pytest.param(torch.tensor([[1.0, 2.0], [0.5, 0.1]]), torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
    ],
)
@pytest.mark.asyncio
async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
    handle = handle_square_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = await bootstrap_from([server])
    client = await P2P.create(initial_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    inp = serialize_torch_tensor(test_input).SerializeToString()
    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.shutdown()
    await client.shutdown()


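# Same round trip, but the remote handler sums a list of serialized tensors.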
@pytest.mark.parametrize(
    "test_input,expected",
    [
        pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
        pytest.param(
            [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
            torch.tensor([[1.2, 1.4], [1.6, 1.8]]),
        ),
    ],
)
@pytest.mark.asyncio
async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
    handle = handle_add_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = await bootstrap_from([server])
    client = await P2P.create(initial_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.shutdown()
    await client.shutdown()


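# If the handler catches an internal exception, the client receives its fallback payload instead of a result.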
@pytest.mark.parametrize(
    "replicate",
    [
        pytest.param(False, id="primary"),
        pytest.param(True, id="replica"),
    ],
)
@pytest.mark.asyncio
async def test_call_peer_error(replicate, handler_name="handle"):
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)
    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)

    nodes = await bootstrap_from([server])
    client_primary = await P2P.create(initial_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)
    await client.wait_for_at_least_n_peers(1)

    inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
    assert result == b"something went wrong :("

    await server_primary.shutdown()
    await server.shutdown()
    await client_primary.shutdown()
    await client.shutdown()


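# Each replica registers its own protocol; once the replicas shut down, the primary must not serve their handlers.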
@pytest.mark.asyncio
async def test_handlers_on_different_replicas(handler_name="handle"):
    def handler(arg, key):
        return key

    server_primary = await P2P.create()
    server_id = server_primary.id
    await server_primary.add_stream_handler(handler_name, partial(handler, key=b"primary"))

    server_replica1 = await replicate_if_needed(server_primary, True)
    await server_replica1.add_stream_handler(handler_name + "1", partial(handler, key=b"replica1"))

    server_replica2 = await replicate_if_needed(server_primary, True)
    await server_replica2.add_stream_handler(handler_name + "2", partial(handler, key=b"replica2"))

    nodes = await bootstrap_from([server_primary])
    client = await P2P.create(initial_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    result = await client.call_peer_handler(server_id, handler_name, b"1")
    assert result == b"primary"
    result = await client.call_peer_handler(server_id, handler_name + "1", b"2")
    assert result == b"replica1"
    result = await client.call_peer_handler(server_id, handler_name + "2", b"3")
    assert result == b"replica2"

    await server_replica1.shutdown()
    await server_replica2.shutdown()

    # The primary does not handle the replicas' protocols after they shut down
    with pytest.raises(Exception):
        await client.call_peer_handler(server_id, handler_name + "1", b"")
    with pytest.raises(Exception):
        await client.call_peer_handler(server_id, handler_name + "2", b"")

    await server_primary.shutdown()
    await client.shutdown()