test_p2p_daemon.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. import asyncio
  2. import multiprocessing as mp
  3. import os
  4. import subprocess
  5. import tempfile
  6. from contextlib import closing
  7. from functools import partial
  8. from typing import List
  9. import numpy as np
  10. import pytest
  11. from multiaddr import Multiaddr
  12. from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
  13. from hivemind.proto import dht_pb2, test_pb2
  14. from hivemind.utils.serializer import MSGPackSerializer
  15. from test_utils.networking import get_free_port
  16. def is_process_running(pid: int) -> bool:
  17. return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
  18. async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
  19. return await P2P.replicate(p2p.daemon_listen_maddr) if replicate else p2p
  20. @pytest.mark.asyncio
  21. async def test_daemon_killed_on_del():
  22. p2p_daemon = await P2P.create()
  23. child_pid = p2p_daemon._child.pid
  24. assert is_process_running(child_pid)
  25. await p2p_daemon.shutdown()
  26. assert not is_process_running(child_pid)
  27. @pytest.mark.asyncio
  28. async def test_startup_error_message():
  29. with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
  30. await P2P.create(
  31. initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
  32. )
  33. with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
  34. await P2P.create(startup_timeout=0.01) # Test that startup_timeout works
  35. @pytest.mark.asyncio
  36. async def test_identity():
  37. with tempfile.TemporaryDirectory() as tempdir:
  38. id1_path = os.path.join(tempdir, "id1")
  39. id2_path = os.path.join(tempdir, "id2")
  40. p2ps = await asyncio.gather(*[P2P.create(identity_path=path) for path in [None, None, id1_path, id2_path]])
  41. # We create the second daemon with id2 separately
  42. # to avoid a race condition while saving a newly generated identity
  43. p2ps.append(await P2P.create(identity_path=id2_path))
  44. # Using the same identity (if any) should lead to the same peer ID
  45. assert p2ps[-2].peer_id == p2ps[-1].peer_id
  46. # The rest of peer IDs should be different
  47. peer_ids = {instance.peer_id for instance in p2ps}
  48. assert len(peer_ids) == 4
  49. for instance in p2ps:
  50. await instance.shutdown()
  51. with pytest.raises(FileNotFoundError, match=r"The directory.+does not exist"):
  52. P2P.generate_identity(id1_path)
  53. @pytest.mark.asyncio
  54. async def test_check_if_identity_free():
  55. with tempfile.TemporaryDirectory() as tempdir:
  56. id1_path = os.path.join(tempdir, "id1")
  57. id2_path = os.path.join(tempdir, "id2")
  58. p2ps = [await P2P.create(identity_path=id1_path)]
  59. initial_peers = await p2ps[0].get_visible_maddrs()
  60. p2ps.append(await P2P.create(initial_peers=initial_peers))
  61. p2ps.append(await P2P.create(initial_peers=initial_peers, identity_path=id2_path))
  62. with pytest.raises(P2PDaemonError, match=r"Identity.+is already taken by another peer"):
  63. await P2P.create(initial_peers=initial_peers, identity_path=id1_path)
  64. with pytest.raises(P2PDaemonError, match=r"Identity.+is already taken by another peer"):
  65. await P2P.create(initial_peers=initial_peers, identity_path=id2_path)
  66. # Must work if a P2P with a certain identity is restarted
  67. await p2ps[-1].shutdown()
  68. p2ps.pop()
  69. p2ps.append(await P2P.create(initial_peers=initial_peers, identity_path=id2_path))
  70. for instance in p2ps:
  71. await instance.shutdown()
  72. @pytest.mark.parametrize(
  73. "host_maddrs",
  74. [
  75. [Multiaddr("/ip4/127.0.0.1/tcp/0")],
  76. [Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
  77. [Multiaddr("/ip4/127.0.0.1/tcp/0"), Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
  78. ],
  79. )
  80. @pytest.mark.asyncio
  81. async def test_transports(host_maddrs: List[Multiaddr]):
  82. server = await P2P.create(host_maddrs=host_maddrs)
  83. peers = await server.list_peers()
  84. assert len(peers) == 0
  85. client = await P2P.create(host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
  86. await client.wait_for_at_least_n_peers(1)
  87. peers = await client.list_peers()
  88. assert len({p.peer_id for p in peers}) == 1
  89. peers = await server.list_peers()
  90. assert len({p.peer_id for p in peers}) == 1
  91. @pytest.mark.asyncio
  92. async def test_daemon_replica_does_not_affect_primary():
  93. p2p_daemon = await P2P.create()
  94. p2p_replica = await P2P.replicate(p2p_daemon.daemon_listen_maddr)
  95. child_pid = p2p_daemon._child.pid
  96. assert is_process_running(child_pid)
  97. await p2p_replica.shutdown()
  98. assert is_process_running(child_pid)
  99. await p2p_daemon.shutdown()
  100. assert not is_process_running(child_pid)
  101. @pytest.mark.asyncio
  102. async def test_unary_handler_edge_cases():
  103. p2p = await P2P.create()
  104. p2p_replica = await P2P.replicate(p2p.daemon_listen_maddr)
  105. async def square_handler(data: test_pb2.TestRequest, context):
  106. return test_pb2.TestResponse(number=data.number**2)
  107. await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
  108. # try adding a duplicate handler
  109. with pytest.raises(P2PDaemonError):
  110. await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
  111. # try adding a duplicate handler from replicated p2p
  112. with pytest.raises(P2PDaemonError):
  113. await p2p_replica.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
  114. # try dialing yourself
  115. with pytest.raises(P2PDaemonError):
  116. await p2p_replica.call_protobuf_handler(
  117. p2p.peer_id, "square", test_pb2.TestRequest(number=41), test_pb2.TestResponse
  118. )
  119. @pytest.mark.parametrize(
  120. "should_cancel,replicate",
  121. [
  122. (True, False),
  123. (True, True),
  124. (False, False),
  125. (False, True),
  126. ],
  127. )
  128. @pytest.mark.asyncio
  129. async def test_call_protobuf_handler(should_cancel, replicate, handle_name="handle"):
  130. handler_cancelled = False
  131. server_primary = await P2P.create()
  132. server = await replicate_if_needed(server_primary, replicate)
  133. async def ping_handler(request, context):
  134. try:
  135. await asyncio.sleep(2)
  136. except asyncio.CancelledError:
  137. nonlocal handler_cancelled
  138. handler_cancelled = True
  139. return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)
  140. server_pid = server_primary._child.pid
  141. await server.add_protobuf_handler(handle_name, ping_handler, dht_pb2.PingRequest)
  142. assert is_process_running(server_pid)
  143. client_primary = await P2P.create(initial_peers=await server.get_visible_maddrs())
  144. client = await replicate_if_needed(client_primary, replicate)
  145. client_pid = client_primary._child.pid
  146. assert is_process_running(client_pid)
  147. await client.wait_for_at_least_n_peers(1)
  148. ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)
  149. expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)
  150. if should_cancel:
  151. call_task = asyncio.create_task(
  152. client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
  153. )
  154. await asyncio.sleep(0.25)
  155. call_task.cancel()
  156. await asyncio.sleep(0.25)
  157. assert handler_cancelled
  158. else:
  159. actual_response = await client.call_protobuf_handler(
  160. server.peer_id, handle_name, ping_request, dht_pb2.PingResponse
  161. )
  162. assert actual_response == expected_response
  163. assert not handler_cancelled
  164. await server.shutdown()
  165. await server_primary.shutdown()
  166. assert not is_process_running(server_pid)
  167. await client_primary.shutdown()
  168. assert not is_process_running(client_pid)
  169. @pytest.mark.asyncio
  170. async def test_call_protobuf_handler_error(handle_name="handle"):
  171. async def error_handler(request, context):
  172. raise ValueError("boom")
  173. server = await P2P.create()
  174. server_pid = server._child.pid
  175. await server.add_protobuf_handler(handle_name, error_handler, dht_pb2.PingRequest)
  176. assert is_process_running(server_pid)
  177. client = await P2P.create(initial_peers=await server.get_visible_maddrs())
  178. client_pid = client._child.pid
  179. assert is_process_running(client_pid)
  180. await client.wait_for_at_least_n_peers(1)
  181. ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)
  182. with pytest.raises(P2PHandlerError) as excinfo:
  183. await client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
  184. assert "boom" in str(excinfo.value)
  185. await server.shutdown()
  186. await client.shutdown()
  187. async def handle_square_stream(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  188. with closing(writer):
  189. while True:
  190. try:
  191. x = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
  192. except asyncio.IncompleteReadError:
  193. break
  194. result = x**2
  195. await P2P.send_raw_data(MSGPackSerializer.dumps(result), writer)
  196. async def validate_square_stream(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  197. with closing(writer):
  198. for _ in range(10):
  199. x = np.random.randint(100)
  200. await P2P.send_raw_data(MSGPackSerializer.dumps(x), writer)
  201. result = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
  202. assert result == x**2
  203. @pytest.mark.asyncio
  204. async def test_call_peer_single_process():
  205. server = await P2P.create()
  206. server_pid = server._child.pid
  207. assert is_process_running(server_pid)
  208. handler_name = "square"
  209. await server.add_binary_stream_handler(handler_name, handle_square_stream)
  210. client = await P2P.create(initial_peers=await server.get_visible_maddrs())
  211. client_pid = client._child.pid
  212. assert is_process_running(client_pid)
  213. await client.wait_for_at_least_n_peers(1)
  214. _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
  215. await validate_square_stream(reader, writer)
  216. await server.shutdown()
  217. assert not is_process_running(server_pid)
  218. await client.shutdown()
  219. assert not is_process_running(client_pid)
  220. async def run_server(handler_name, server_side, response_received):
  221. server = await P2P.create()
  222. server_pid = server._child.pid
  223. assert is_process_running(server_pid)
  224. await server.add_binary_stream_handler(handler_name, handle_square_stream)
  225. server_side.send(server.peer_id)
  226. server_side.send(await server.get_visible_maddrs())
  227. while response_received.value == 0:
  228. await asyncio.sleep(0.5)
  229. await server.shutdown()
  230. assert not is_process_running(server_pid)
  231. def server_target(handler_name, server_side, response_received):
  232. asyncio.run(run_server(handler_name, server_side, response_received))
  233. @pytest.mark.asyncio
  234. async def test_call_peer_different_processes():
  235. handler_name = "square"
  236. server_side, client_side = mp.Pipe()
  237. response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
  238. response_received.value = 0
  239. proc = mp.Process(target=server_target, args=(handler_name, server_side, response_received))
  240. proc.start()
  241. peer_id = client_side.recv()
  242. peer_maddrs = client_side.recv()
  243. client = await P2P.create(initial_peers=peer_maddrs)
  244. client_pid = client._child.pid
  245. assert is_process_running(client_pid)
  246. await client.wait_for_at_least_n_peers(1)
  247. _, reader, writer = await client.call_binary_stream_handler(peer_id, handler_name)
  248. await validate_square_stream(reader, writer)
  249. response_received.value = 1
  250. await client.shutdown()
  251. assert not is_process_running(client_pid)
  252. proc.join()
  253. assert proc.exitcode == 0
  254. @pytest.mark.asyncio
  255. async def test_error_closes_connection():
  256. async def handle_raising_error(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  257. with closing(writer):
  258. command = await P2P.receive_raw_data(reader)
  259. if command == b"raise_error":
  260. raise Exception("The handler has failed")
  261. else:
  262. await P2P.send_raw_data(b"okay", writer)
  263. server = await P2P.create()
  264. server_pid = server._child.pid
  265. assert is_process_running(server_pid)
  266. handler_name = "handler"
  267. await server.add_binary_stream_handler(handler_name, handle_raising_error)
  268. client = await P2P.create(initial_peers=await server.get_visible_maddrs())
  269. client_pid = client._child.pid
  270. assert is_process_running(client_pid)
  271. await client.wait_for_at_least_n_peers(1)
  272. _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
  273. with closing(writer):
  274. await P2P.send_raw_data(b"raise_error", writer)
  275. with pytest.raises(asyncio.IncompleteReadError): # Means that the connection is closed
  276. await P2P.receive_raw_data(reader)
  277. # Despite the handler raised an exception, the server did not crash and ready for next requests
  278. assert is_process_running(server_pid)
  279. _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
  280. with closing(writer):
  281. await P2P.send_raw_data(b"behave_normally", writer)
  282. assert await P2P.receive_raw_data(reader) == b"okay"
  283. await server.shutdown()
  284. assert not is_process_running(server_pid)
  285. await client.shutdown()
  286. assert not is_process_running(client_pid)
  287. @pytest.mark.asyncio
  288. async def test_handlers_on_different_replicas():
  289. async def handler(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, key: str) -> None:
  290. with closing(writer):
  291. await P2P.send_raw_data(key, writer)
  292. server_primary = await P2P.create()
  293. server_id = server_primary.peer_id
  294. await server_primary.add_binary_stream_handler("handle_primary", partial(handler, key=b"primary"))
  295. server_replica1 = await replicate_if_needed(server_primary, True)
  296. await server_replica1.add_binary_stream_handler("handle1", partial(handler, key=b"replica1"))
  297. server_replica2 = await replicate_if_needed(server_primary, True)
  298. await server_replica2.add_binary_stream_handler("handle2", partial(handler, key=b"replica2"))
  299. client = await P2P.create(initial_peers=await server_primary.get_visible_maddrs())
  300. await client.wait_for_at_least_n_peers(1)
  301. for name, expected_key in [("handle_primary", b"primary"), ("handle1", b"replica1"), ("handle2", b"replica2")]:
  302. _, reader, writer = await client.call_binary_stream_handler(server_id, name)
  303. with closing(writer):
  304. assert await P2P.receive_raw_data(reader) == expected_key
  305. await server_replica1.shutdown()
  306. await server_replica2.shutdown()
  307. # Primary does not handle replicas protocols after their shutdown
  308. for name in ["handle1", "handle2"]:
  309. _, reader, writer = await client.call_binary_stream_handler(server_id, name)
  310. with pytest.raises(asyncio.IncompleteReadError), closing(writer):
  311. await P2P.receive_raw_data(reader)
  312. await server_primary.shutdown()
  313. await client.shutdown()