test_p2p_daemon_bindings.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. import asyncio
  2. import io
  3. from contextlib import AsyncExitStack
  4. import pytest
  5. from google.protobuf.message import EncodeError
  6. from multiaddr import Multiaddr, protocols
  7. from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, parse_conn_protocol
  8. from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
  9. from hivemind.p2p.p2p_daemon_bindings.utils import (
  10. ControlFailure,
  11. raise_if_failed,
  12. read_pbmsg_safe,
  13. read_unsigned_varint,
  14. write_pbmsg,
  15. write_unsigned_varint,
  16. )
  17. from hivemind.proto import p2pd_pb2 as p2pd_pb
  18. from test_utils.p2p_daemon import connect_safe, make_p2pd_pair_unix
  19. def test_raise_if_failed_raises():
  20. resp = p2pd_pb.Response()
  21. resp.type = p2pd_pb.Response.ERROR
  22. with pytest.raises(ControlFailure):
  23. raise_if_failed(resp)
  24. def test_raise_if_failed_not_raises():
  25. resp = p2pd_pb.Response()
  26. resp.type = p2pd_pb.Response.OK
  27. raise_if_failed(resp)
  28. PAIRS_INT_SERIALIZED_VALID = (
  29. (0, b"\x00"),
  30. (1, b"\x01"),
  31. (128, b"\x80\x01"),
  32. (2**32, b"\x80\x80\x80\x80\x10"),
  33. (2**64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
  34. )
  35. PAIRS_INT_SERIALIZED_OVERFLOW = (
  36. (2**64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
  37. (2**64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
  38. (
  39. 2**128,
  40. b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
  41. ),
  42. )
  43. PEER_ID_STRING = "QmS5QmciTXXnCUCyxud5eWFenUMAmvAWSDa1c7dvdXRMZ7"
  44. PEER_ID_BYTES = b'\x12 7\x87F.[\xb5\xb1o\xe5*\xc7\xb9\xbb\x11:"Z|j2\x8ad\x1b\xa6\xe5<Ip\xfe\xb4\xf5v'
  45. PEER_ID = PeerID(PEER_ID_BYTES)
  46. MADDR = Multiaddr("/unix/123")
  47. NUM_P2PDS = 4
  48. PEER_ID_RANDOM = PeerID.from_base58("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNK1")
  49. ENABLE_CONTROL = True
  50. ENABLE_CONNMGR = False
  51. ENABLE_DHT = False
  52. ENABLE_PUBSUB = False
  53. FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_unix
  54. class MockReader(io.BytesIO):
  55. async def readexactly(self, n):
  56. await asyncio.sleep(0)
  57. return self.read(n)
  58. class MockWriter(io.BytesIO):
  59. pass
  60. class MockReaderWriter(MockReader, MockWriter):
  61. pass
  62. @pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
  63. @pytest.mark.asyncio
  64. async def test_write_unsigned_varint(integer, serialized_integer):
  65. s = MockWriter()
  66. await write_unsigned_varint(s, integer)
  67. assert s.getvalue() == serialized_integer
  68. @pytest.mark.parametrize("integer", tuple(i[0] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
  69. @pytest.mark.asyncio
  70. async def test_write_unsigned_varint_overflow(integer):
  71. s = MockWriter()
  72. with pytest.raises(ValueError):
  73. await write_unsigned_varint(s, integer)
  74. @pytest.mark.parametrize("integer", (-1, -(2**32), -(2**64), -(2**128)))
  75. @pytest.mark.asyncio
  76. async def test_write_unsigned_varint_negative(integer):
  77. s = MockWriter()
  78. with pytest.raises(ValueError):
  79. await write_unsigned_varint(s, integer)
  80. @pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
  81. @pytest.mark.asyncio
  82. async def test_read_unsigned_varint(integer, serialized_integer):
  83. s = MockReader(serialized_integer)
  84. result = await read_unsigned_varint(s)
  85. assert result == integer
  86. @pytest.mark.parametrize("serialized_integer", tuple(i[1] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
  87. @pytest.mark.asyncio
  88. async def test_read_unsigned_varint_overflow(serialized_integer):
  89. s = MockReader(serialized_integer)
  90. with pytest.raises(ValueError):
  91. await read_unsigned_varint(s)
  92. @pytest.mark.parametrize("max_bits", (2, 31, 32, 63, 64, 127, 128))
  93. @pytest.mark.asyncio
  94. async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
  95. """
  96. Test edge cases with different `max_bits`
  97. """
  98. for i in range(-3, 0):
  99. integer = i + 2**max_bits
  100. s = MockReaderWriter()
  101. await write_unsigned_varint(s, integer, max_bits=max_bits)
  102. s.seek(0, 0)
  103. result = await read_unsigned_varint(s, max_bits=max_bits)
  104. assert integer == result
  105. def test_peer_id():
  106. assert PEER_ID.to_bytes() == PEER_ID_BYTES
  107. assert PEER_ID.to_string() == PEER_ID_STRING
  108. peer_id_2 = PeerID.from_base58(PEER_ID_STRING)
  109. assert peer_id_2.to_bytes() == PEER_ID_BYTES
  110. assert peer_id_2.to_string() == PEER_ID_STRING
  111. assert PEER_ID == peer_id_2
  112. peer_id_3 = PeerID.from_base58("QmbmfNDEth7Ucvjuxiw3SP3E4PoJzbk7g4Ge6ZDigbCsNp")
  113. assert PEER_ID != peer_id_3
  114. a = PeerID.from_base58("bob")
  115. b = PeerID.from_base58("eve")
  116. assert a < b and b > a and not (b < a) and not (a > b)
  117. with pytest.raises(TypeError):
  118. assert a < object()
  119. def test_stream_info():
  120. proto = "123"
  121. si = StreamInfo(PEER_ID, MADDR, proto)
  122. assert si.peer_id == PEER_ID
  123. assert si.addr == MADDR
  124. assert si.proto == proto
  125. pb_si = si.to_protobuf()
  126. assert pb_si.peer == PEER_ID.to_bytes()
  127. assert pb_si.addr == MADDR.to_bytes()
  128. assert pb_si.proto == si.proto
  129. si_1 = StreamInfo.from_protobuf(pb_si)
  130. assert si_1.peer_id == PEER_ID
  131. assert si_1.addr == MADDR
  132. assert si_1.proto == proto
  133. def test_peer_info():
  134. pi = PeerInfo(PEER_ID, [MADDR])
  135. assert pi.peer_id == PEER_ID
  136. assert pi.addrs == [MADDR]
  137. pi_pb = p2pd_pb.PeerInfo(id=PEER_ID.to_bytes(), addrs=[MADDR.to_bytes()])
  138. pi_1 = PeerInfo.from_protobuf(pi_pb)
  139. assert pi.peer_id == pi_1.peer_id
  140. assert pi.addrs == pi_1.addrs
  141. @pytest.mark.parametrize(
  142. "maddr_str, expected_proto",
  143. (("/unix/123", protocols.P_UNIX), ("/ip4/127.0.0.1/tcp/7777", protocols.P_IP4)),
  144. )
  145. def test_parse_conn_protocol_valid(maddr_str, expected_proto):
  146. assert parse_conn_protocol(Multiaddr(maddr_str)) == expected_proto
  147. @pytest.mark.parametrize(
  148. "maddr_str",
  149. (
  150. "/p2p/QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP",
  151. "/onion/timaq4ygg2iegci7:1234",
  152. ),
  153. )
  154. def test_parse_conn_protocol_invalid(maddr_str):
  155. maddr = Multiaddr(maddr_str)
  156. with pytest.raises(ValueError):
  157. parse_conn_protocol(maddr)
  158. @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
  159. @pytest.mark.asyncio
  160. async def test_client_create_control_maddr(control_maddr_str):
  161. c = DaemonConnector(Multiaddr(control_maddr_str))
  162. assert c.control_maddr == Multiaddr(control_maddr_str)
  163. def test_client_create_default_control_maddr():
  164. c = DaemonConnector()
  165. assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
  166. @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
  167. @pytest.mark.asyncio
  168. async def test_control_client_create_listen_maddr(listen_maddr_str):
  169. c = await ControlClient.create(
  170. daemon_connector=DaemonConnector(),
  171. listen_maddr=Multiaddr(listen_maddr_str),
  172. use_persistent_conn=False,
  173. )
  174. assert c.listen_maddr == Multiaddr(listen_maddr_str)
  175. @pytest.mark.asyncio
  176. async def test_control_client_create_default_listen_maddr():
  177. c = await ControlClient.create(daemon_connector=DaemonConnector(), use_persistent_conn=False)
  178. assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
  179. @pytest.mark.parametrize(
  180. "msg_bytes",
  181. (
  182. p2pd_pb.Response(
  183. type=p2pd_pb.Response.Type.OK,
  184. identify=p2pd_pb.IdentifyResponse(
  185. id=PeerID.from_base58("QmT7WhTne9zBLfAgAJt9aiZ8jZ5BxJGowRubxsHYmnyzUd").to_bytes(),
  186. addrs=[
  187. Multiaddr("/p2p-circuit").to_bytes(),
  188. Multiaddr("/ip4/127.0.0.1/tcp/51126").to_bytes(),
  189. Multiaddr("/ip4/192.168.10.135/tcp/51126").to_bytes(),
  190. Multiaddr("/ip6/::1/tcp/51127").to_bytes(),
  191. ],
  192. ),
  193. ).SerializeToString(),
  194. p2pd_pb.Response(
  195. type=p2pd_pb.Response.Type.OK,
  196. identify=p2pd_pb.IdentifyResponse(
  197. id=PeerID.from_base58("QmcQFt2MFfCZ9AxzUCNrk4k7TtMdZZvAAteaA6tHpBKdrk").to_bytes(),
  198. addrs=[
  199. Multiaddr("/p2p-circuit").to_bytes(),
  200. Multiaddr("/ip4/127.0.0.1/tcp/51493").to_bytes(),
  201. Multiaddr("/ip4/192.168.10.135/tcp/51493").to_bytes(),
  202. Multiaddr("/ip6/::1/tcp/51494").to_bytes(),
  203. ],
  204. ),
  205. ).SerializeToString(),
  206. p2pd_pb.Response(
  207. type=p2pd_pb.Response.Type.OK,
  208. identify=p2pd_pb.IdentifyResponse(
  209. id=PeerID.from_base58("QmbWqVVoz7v9LS9ZUQAhyyfdFJY3iU8ZrUY3XQozoTA5cc").to_bytes(),
  210. addrs=[
  211. Multiaddr("/p2p-circuit").to_bytes(),
  212. Multiaddr("/ip4/127.0.0.1/tcp/51552").to_bytes(),
  213. Multiaddr("/ip4/192.168.10.135/tcp/51552").to_bytes(),
  214. Multiaddr("/ip6/::1/tcp/51553").to_bytes(),
  215. ],
  216. ),
  217. ).SerializeToString(),
  218. ),
  219. # give test cases ids to prevent bytes from ruining the terminal
  220. ids=("pb example Response 0", "pb example Response 1", "pb example Response 2"),
  221. )
  222. @pytest.mark.asyncio
  223. async def test_read_pbmsg_safe_valid(msg_bytes):
  224. s = MockReaderWriter()
  225. await write_unsigned_varint(s, len(msg_bytes))
  226. s.write(msg_bytes)
  227. # reset the offset back to the beginning
  228. s.seek(0, 0)
  229. pb_msg = p2pd_pb.Response()
  230. await read_pbmsg_safe(s, pb_msg)
  231. assert pb_msg.SerializeToString() == msg_bytes
  232. @pytest.mark.parametrize(
  233. "pb_type, pb_msg",
  234. (
  235. (
  236. p2pd_pb.Response,
  237. p2pd_pb.Response(
  238. type=p2pd_pb.Response.Type.OK,
  239. dht=p2pd_pb.DHTResponse(
  240. type=p2pd_pb.DHTResponse.Type.VALUE,
  241. peer=p2pd_pb.PeerInfo(
  242. id=PeerID.from_base58("QmNaXUy78W9moQ9APCoKaTtPjLcEJPN9hRBCqErY7o2fQs").to_bytes(),
  243. addrs=[
  244. Multiaddr("/p2p-circuit").to_bytes(),
  245. Multiaddr("/ip4/127.0.0.1/tcp/56929").to_bytes(),
  246. Multiaddr("/ip4/192.168.10.135/tcp/56929").to_bytes(),
  247. Multiaddr("/ip6/::1/tcp/56930").to_bytes(),
  248. ],
  249. ),
  250. ),
  251. ),
  252. ),
  253. (p2pd_pb.Request, p2pd_pb.Request(type=p2pd_pb.Request.Type.LIST_PEERS)),
  254. (
  255. p2pd_pb.DHTRequest,
  256. p2pd_pb.DHTRequest(
  257. type=p2pd_pb.DHTRequest.Type.FIND_PEER,
  258. peer=PeerID.from_base58("QmcgHMuEhqdLHDVeNjiCGU7Ds6E7xK3f4amgiwHNPKKn7R").to_bytes(),
  259. ),
  260. ),
  261. (
  262. p2pd_pb.DHTResponse,
  263. p2pd_pb.DHTResponse(
  264. type=p2pd_pb.DHTResponse.Type.VALUE,
  265. peer=p2pd_pb.PeerInfo(
  266. id=PeerID.from_base58("QmWP32GhEyXVQsLXFvV81eadDC8zQRZxZvJK359rXxLquk").to_bytes(),
  267. addrs=[
  268. Multiaddr("/p2p-circuit").to_bytes(),
  269. Multiaddr("/ip4/127.0.0.1/tcp/56897").to_bytes(),
  270. Multiaddr("/ip4/192.168.10.135/tcp/56897").to_bytes(),
  271. Multiaddr("/ip6/::1/tcp/56898").to_bytes(),
  272. ],
  273. ),
  274. ),
  275. ),
  276. (
  277. p2pd_pb.StreamInfo,
  278. p2pd_pb.StreamInfo(
  279. peer=PeerID.from_base58("QmewLxB46MftfxQiunRgJo2W8nW4Lh5NLEkRohkHhJ4wW6").to_bytes(),
  280. addr=Multiaddr("/ip4/127.0.0.1/tcp/57029").to_bytes(),
  281. proto=b"protocol123",
  282. ),
  283. ),
  284. ),
  285. ids=(
  286. "pb example Response",
  287. "pb example Request",
  288. "pb example DHTRequest",
  289. "pb example DHTResponse",
  290. "pb example StreamInfo",
  291. ),
  292. )
  293. @pytest.mark.asyncio
  294. async def test_write_pbmsg(pb_type, pb_msg):
  295. msg_bytes = bytes(chr(pb_msg.ByteSize()), "utf-8") + pb_msg.SerializeToString()
  296. pb_obj = pb_type()
  297. s_read = MockReaderWriter(msg_bytes)
  298. await read_pbmsg_safe(s_read, pb_obj)
  299. s_write = MockReaderWriter()
  300. await write_pbmsg(s_write, pb_obj)
  301. assert msg_bytes == s_write.getvalue()
  302. @pytest.mark.parametrize(
  303. "pb_msg",
  304. (
  305. p2pd_pb.Response(),
  306. p2pd_pb.Request(),
  307. p2pd_pb.DHTRequest(),
  308. p2pd_pb.DHTResponse(),
  309. p2pd_pb.StreamInfo(),
  310. ),
  311. )
  312. @pytest.mark.asyncio
  313. async def test_write_pbmsg_missing_fields(pb_msg):
  314. with pytest.raises(EncodeError):
  315. await write_pbmsg(MockReaderWriter(), pb_msg)
  316. @pytest.fixture
  317. async def p2pcs():
  318. # TODO: Change back to gather style
  319. async with AsyncExitStack() as stack:
  320. p2pd_tuples = [
  321. await stack.enter_async_context(
  322. FUNC_MAKE_P2PD_PAIR(
  323. enable_control=ENABLE_CONTROL,
  324. enable_connmgr=ENABLE_CONNMGR,
  325. enable_dht=ENABLE_DHT,
  326. enable_pubsub=ENABLE_PUBSUB,
  327. )
  328. )
  329. for _ in range(NUM_P2PDS)
  330. ]
  331. yield tuple(p2pd_tuple.client for p2pd_tuple in p2pd_tuples)
  332. @pytest.mark.asyncio
  333. async def test_client_identify(p2pcs):
  334. await p2pcs[0].identify()
  335. @pytest.mark.asyncio
  336. async def test_client_connect_success(p2pcs):
  337. peer_id_0, maddrs_0 = await p2pcs[0].identify()
  338. peer_id_1, maddrs_1 = await p2pcs[1].identify()
  339. await p2pcs[0].connect(peer_id_1, maddrs_1)
  340. # test case: repeated connections
  341. await p2pcs[1].connect(peer_id_0, maddrs_0)
  342. @pytest.mark.asyncio
  343. async def test_client_connect_failure(p2pcs):
  344. peer_id_1, maddrs_1 = await p2pcs[1].identify()
  345. await p2pcs[0].identify()
  346. # test case: `peer_id` mismatches
  347. with pytest.raises(ControlFailure):
  348. await p2pcs[0].connect(PEER_ID_RANDOM, maddrs_1)
  349. # test case: empty maddrs
  350. with pytest.raises(ControlFailure):
  351. await p2pcs[0].connect(peer_id_1, [])
  352. # test case: wrong maddrs
  353. with pytest.raises(ControlFailure):
  354. await p2pcs[0].connect(peer_id_1, [Multiaddr("/ip4/127.0.0.1/udp/0")])
  355. @pytest.mark.asyncio
  356. async def test_connect_safe(p2pcs):
  357. await connect_safe(p2pcs[0], p2pcs[1])
  358. @pytest.mark.asyncio
  359. async def test_client_list_peers(p2pcs):
  360. # test case: no peers
  361. assert len(await p2pcs[0].list_peers()) == 0
  362. # test case: 1 peer
  363. await connect_safe(p2pcs[0], p2pcs[1])
  364. assert len(await p2pcs[0].list_peers()) == 1
  365. assert len(await p2pcs[1].list_peers()) == 1
  366. # test case: one more peer
  367. await connect_safe(p2pcs[0], p2pcs[2])
  368. assert len(await p2pcs[0].list_peers()) == 2
  369. assert len(await p2pcs[1].list_peers()) == 1
  370. assert len(await p2pcs[2].list_peers()) == 1
  371. @pytest.mark.asyncio
  372. async def test_client_disconnect(p2pcs):
  373. # test case: disconnect a peer without connections
  374. await p2pcs[1].disconnect(PEER_ID_RANDOM)
  375. # test case: disconnect
  376. peer_id_0, _ = await p2pcs[0].identify()
  377. await connect_safe(p2pcs[0], p2pcs[1])
  378. assert len(await p2pcs[0].list_peers()) == 1
  379. assert len(await p2pcs[1].list_peers()) == 1
  380. await p2pcs[1].disconnect(peer_id_0)
  381. assert len(await p2pcs[0].list_peers()) == 0
  382. assert len(await p2pcs[1].list_peers()) == 0
  383. # test case: disconnect twice
  384. await p2pcs[1].disconnect(peer_id_0)
  385. assert len(await p2pcs[0].list_peers()) == 0
  386. assert len(await p2pcs[1].list_peers()) == 0
  387. @pytest.mark.asyncio
  388. async def test_client_stream_open_success(p2pcs):
  389. peer_id_1, maddrs_1 = await p2pcs[1].identify()
  390. await connect_safe(p2pcs[0], p2pcs[1])
  391. proto = "123"
  392. async def handle_proto(stream_info, reader, writer):
  393. await reader.readexactly(1)
  394. await p2pcs[1].stream_handler(proto, handle_proto)
  395. # test case: normal
  396. stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
  397. assert stream_info.peer_id == peer_id_1
  398. assert stream_info.addr in maddrs_1
  399. assert stream_info.proto == "123"
  400. writer.close()
  401. # test case: open with multiple protocols
  402. stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto, "another_protocol"))
  403. assert stream_info.peer_id == peer_id_1
  404. assert stream_info.addr in maddrs_1
  405. assert stream_info.proto == "123"
  406. writer.close()
  407. @pytest.mark.asyncio
  408. async def test_client_stream_open_failure(p2pcs):
  409. peer_id_1, _ = await p2pcs[1].identify()
  410. await connect_safe(p2pcs[0], p2pcs[1])
  411. proto = "123"
  412. # test case: `stream_open` to a peer who didn't register the protocol
  413. with pytest.raises(ControlFailure):
  414. await p2pcs[0].stream_open(peer_id_1, (proto,))
  415. # test case: `stream_open` to a peer for a non-registered protocol
  416. async def handle_proto(stream_info, reader, writer):
  417. pass
  418. await p2pcs[1].stream_handler(proto, handle_proto)
  419. with pytest.raises(ControlFailure):
  420. await p2pcs[0].stream_open(peer_id_1, ("another_protocol",))
  421. @pytest.mark.asyncio
  422. async def test_client_stream_handler_success(p2pcs):
  423. peer_id_1, _ = await p2pcs[1].identify()
  424. await connect_safe(p2pcs[0], p2pcs[1])
  425. proto = "protocol123"
  426. bytes_to_send = b"yoyoyoyoyog"
  427. # event for this test function to wait until the handler function receiving the incoming data
  428. event_handler_finished = asyncio.Event()
  429. async def handle_proto(stream_info, reader, writer):
  430. nonlocal event_handler_finished
  431. bytes_received = await reader.readexactly(len(bytes_to_send))
  432. assert bytes_received == bytes_to_send
  433. event_handler_finished.set()
  434. await p2pcs[1].stream_handler(proto, handle_proto)
  435. assert proto in p2pcs[1].control.handlers
  436. assert handle_proto == p2pcs[1].control.handlers[proto]
  437. # test case: test the stream handler `handle_proto`
  438. _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
  439. # wait until the handler function starts blocking waiting for the data
  440. # because we haven't sent the data, we know the handler function must still blocking waiting.
  441. # get the task of the protocol handler
  442. writer.write(bytes_to_send)
  443. # wait for the handler to finish
  444. writer.close()
  445. await event_handler_finished.wait()
  446. # test case: two streams to different handlers respectively
  447. another_proto = "another_protocol123"
  448. another_bytes_to_send = b"456"
  449. event_another_proto = asyncio.Event()
  450. async def handle_another_proto(stream_info, reader, writer):
  451. event_another_proto.set()
  452. bytes_received = await reader.readexactly(len(another_bytes_to_send))
  453. assert bytes_received == another_bytes_to_send
  454. await p2pcs[1].stream_handler(another_proto, handle_another_proto)
  455. assert another_proto in p2pcs[1].control.handlers
  456. assert handle_another_proto == p2pcs[1].control.handlers[another_proto]
  457. _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (another_proto,))
  458. await event_another_proto.wait()
  459. # we know at this moment the handler must still blocking wait
  460. writer.write(another_bytes_to_send)
  461. writer.close()
  462. # test case: registering twice can override the previous registration
  463. event_third = asyncio.Event()
  464. async def handler_third(stream_info, reader, writer):
  465. event_third.set()
  466. await p2pcs[1].stream_handler(another_proto, handler_third)
  467. assert another_proto in p2pcs[1].control.handlers
  468. # ensure the handler is override
  469. assert handler_third == p2pcs[1].control.handlers[another_proto]
  470. await p2pcs[0].stream_open(peer_id_1, (another_proto,))
  471. # ensure the overriding handler is called when the protocol is opened a stream
  472. await event_third.wait()
  473. @pytest.mark.asyncio
  474. async def test_client_stream_handler_failure(p2pcs):
  475. peer_id_1, _ = await p2pcs[1].identify()
  476. await connect_safe(p2pcs[0], p2pcs[1])
  477. proto = "123"
  478. # test case: registered a wrong protocol name
  479. async def handle_proto_correct_params(stream_info, stream):
  480. pass
  481. await p2pcs[1].stream_handler("another_protocol", handle_proto_correct_params)
  482. with pytest.raises(ControlFailure):
  483. await p2pcs[0].stream_open(peer_id_1, (proto,))