test_p2p_daemon_bindings.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
  1. import asyncio
  2. import io
  3. from contextlib import AsyncExitStack
  4. import pytest
  5. from google.protobuf.message import EncodeError
  6. from multiaddr import Multiaddr, protocols
  7. from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, parse_conn_protocol
  8. from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
  9. from hivemind.p2p.p2p_daemon_bindings.utils import (ControlFailure, raise_if_failed, read_pbmsg_safe,
  10. read_unsigned_varint, write_pbmsg, write_unsigned_varint)
  11. from hivemind.proto import p2pd_pb2 as p2pd_pb
  12. from test_utils.p2p_daemon import make_p2pd_pair_ip4, connect_safe
  13. def test_raise_if_failed_raises():
  14. resp = p2pd_pb.Response()
  15. resp.type = p2pd_pb.Response.ERROR
  16. with pytest.raises(ControlFailure):
  17. raise_if_failed(resp)
  18. def test_raise_if_failed_not_raises():
  19. resp = p2pd_pb.Response()
  20. resp.type = p2pd_pb.Response.OK
  21. raise_if_failed(resp)
  22. PAIRS_INT_SERIALIZED_VALID = (
  23. (0, b"\x00"),
  24. (1, b"\x01"),
  25. (128, b"\x80\x01"),
  26. (2 ** 32, b"\x80\x80\x80\x80\x10"),
  27. (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
  28. )
  29. PAIRS_INT_SERIALIZED_OVERFLOW = (
  30. (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
  31. (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
  32. (
  33. 2 ** 128,
  34. b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
  35. ),
  36. )
  37. PEER_ID_STRING = "QmS5QmciTXXnCUCyxud5eWFenUMAmvAWSDa1c7dvdXRMZ7"
  38. PEER_ID_BYTES = b'\x12 7\x87F.[\xb5\xb1o\xe5*\xc7\xb9\xbb\x11:"Z|j2\x8ad\x1b\xa6\xe5<Ip\xfe\xb4\xf5v'
  39. PEER_ID = PeerID(PEER_ID_BYTES)
  40. MADDR = Multiaddr("/unix/123")
  41. NUM_P2PDS = 4
  42. PEER_ID_RANDOM = PeerID.from_base58("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNK1")
  43. ENABLE_CONTROL = True
  44. ENABLE_CONNMGR = False
  45. ENABLE_DHT = False
  46. ENABLE_PUBSUB = False
  47. FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_ip4
  48. class MockReader(io.BytesIO):
  49. async def readexactly(self, n):
  50. await asyncio.sleep(0)
  51. return self.read(n)
  52. class MockWriter(io.BytesIO):
  53. pass
  54. class MockReaderWriter(MockReader, MockWriter):
  55. pass
  56. @pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
  57. @pytest.mark.asyncio
  58. async def test_write_unsigned_varint(integer, serialized_integer):
  59. s = MockWriter()
  60. await write_unsigned_varint(s, integer)
  61. assert s.getvalue() == serialized_integer
  62. @pytest.mark.parametrize("integer", tuple(i[0] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
  63. @pytest.mark.asyncio
  64. async def test_write_unsigned_varint_overflow(integer):
  65. s = MockWriter()
  66. with pytest.raises(ValueError):
  67. await write_unsigned_varint(s, integer)
  68. @pytest.mark.parametrize("integer", (-1, -(2 ** 32), -(2 ** 64), -(2 ** 128)))
  69. @pytest.mark.asyncio
  70. async def test_write_unsigned_varint_negative(integer):
  71. s = MockWriter()
  72. with pytest.raises(ValueError):
  73. await write_unsigned_varint(s, integer)
  74. @pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
  75. @pytest.mark.asyncio
  76. async def test_read_unsigned_varint(integer, serialized_integer):
  77. s = MockReader(serialized_integer)
  78. result = await read_unsigned_varint(s)
  79. assert result == integer
  80. @pytest.mark.parametrize("serialized_integer", tuple(i[1] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
  81. @pytest.mark.asyncio
  82. async def test_read_unsigned_varint_overflow(serialized_integer):
  83. s = MockReader(serialized_integer)
  84. with pytest.raises(ValueError):
  85. await read_unsigned_varint(s)
  86. @pytest.mark.parametrize("max_bits", (2, 31, 32, 63, 64, 127, 128))
  87. @pytest.mark.asyncio
  88. async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
  89. """
  90. Test edge cases with different `max_bits`
  91. """
  92. for i in range(-3, 0):
  93. integer = i + (2 ** max_bits)
  94. s = MockReaderWriter()
  95. await write_unsigned_varint(s, integer, max_bits=max_bits)
  96. s.seek(0, 0)
  97. result = await read_unsigned_varint(s, max_bits=max_bits)
  98. assert integer == result
  99. def test_peer_id():
  100. assert PEER_ID.to_bytes() == PEER_ID_BYTES
  101. assert PEER_ID.to_string() == PEER_ID_STRING
  102. peer_id_2 = PeerID.from_base58(PEER_ID_STRING)
  103. assert peer_id_2.to_bytes() == PEER_ID_BYTES
  104. assert peer_id_2.to_string() == PEER_ID_STRING
  105. assert PEER_ID == peer_id_2
  106. peer_id_3 = PeerID.from_base58("QmbmfNDEth7Ucvjuxiw3SP3E4PoJzbk7g4Ge6ZDigbCsNp")
  107. assert PEER_ID != peer_id_3
  108. def test_stream_info():
  109. proto = "123"
  110. si = StreamInfo(PEER_ID, MADDR, proto)
  111. assert si.peer_id == PEER_ID
  112. assert si.addr == MADDR
  113. assert si.proto == proto
  114. pb_si = si.to_protobuf()
  115. assert pb_si.peer == PEER_ID.to_bytes()
  116. assert pb_si.addr == MADDR.to_bytes()
  117. assert pb_si.proto == si.proto
  118. si_1 = StreamInfo.from_protobuf(pb_si)
  119. assert si_1.peer_id == PEER_ID
  120. assert si_1.addr == MADDR
  121. assert si_1.proto == proto
  122. def test_peer_info():
  123. pi = PeerInfo(PEER_ID, [MADDR])
  124. assert pi.peer_id == PEER_ID
  125. assert pi.addrs == [MADDR]
  126. pi_pb = p2pd_pb.PeerInfo(id=PEER_ID.to_bytes(), addrs=[MADDR.to_bytes()])
  127. pi_1 = PeerInfo.from_protobuf(pi_pb)
  128. assert pi.peer_id == pi_1.peer_id
  129. assert pi.addrs == pi_1.addrs
  130. @pytest.mark.parametrize(
  131. "maddr_str, expected_proto",
  132. (("/unix/123", protocols.P_UNIX), ("/ip4/127.0.0.1/tcp/7777", protocols.P_IP4)),
  133. )
  134. def test_parse_conn_protocol_valid(maddr_str, expected_proto):
  135. assert parse_conn_protocol(Multiaddr(maddr_str)) == expected_proto
  136. @pytest.mark.parametrize(
  137. "maddr_str",
  138. (
  139. "/p2p/QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP",
  140. "/onion/timaq4ygg2iegci7:1234",
  141. ),
  142. )
  143. def test_parse_conn_protocol_invalid(maddr_str):
  144. maddr = Multiaddr(maddr_str)
  145. with pytest.raises(ValueError):
  146. parse_conn_protocol(maddr)
  147. @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
  148. def test_client_ctor_control_maddr(control_maddr_str):
  149. c = DaemonConnector(Multiaddr(control_maddr_str))
  150. assert c.control_maddr == Multiaddr(control_maddr_str)
  151. def test_client_ctor_default_control_maddr():
  152. c = DaemonConnector()
  153. assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
  154. @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
  155. def test_control_client_ctor_listen_maddr(listen_maddr_str):
  156. c = ControlClient(
  157. daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str)
  158. )
  159. assert c.listen_maddr == Multiaddr(listen_maddr_str)
  160. def test_control_client_ctor_default_listen_maddr():
  161. c = ControlClient(daemon_connector=DaemonConnector())
  162. assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
  163. @pytest.mark.parametrize(
  164. "msg_bytes",
  165. (
  166. p2pd_pb.Response(
  167. type=p2pd_pb.Response.Type.OK,
  168. identify=p2pd_pb.IdentifyResponse(
  169. id=PeerID.from_base58('QmT7WhTne9zBLfAgAJt9aiZ8jZ5BxJGowRubxsHYmnyzUd').to_bytes(),
  170. addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51126').to_bytes(),
  171. Multiaddr('/ip4/192.168.10.135/tcp/51126').to_bytes(),
  172. Multiaddr('/ip6/::1/tcp/51127').to_bytes()]
  173. )).SerializeToString(),
  174. p2pd_pb.Response(
  175. type=p2pd_pb.Response.Type.OK,
  176. identify=p2pd_pb.IdentifyResponse(
  177. id=PeerID.from_base58('QmcQFt2MFfCZ9AxzUCNrk4k7TtMdZZvAAteaA6tHpBKdrk').to_bytes(),
  178. addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51493').to_bytes(),
  179. Multiaddr('/ip4/192.168.10.135/tcp/51493').to_bytes(),
  180. Multiaddr('/ip6/::1/tcp/51494').to_bytes()]
  181. )).SerializeToString(),
  182. p2pd_pb.Response(
  183. type=p2pd_pb.Response.Type.OK,
  184. identify=p2pd_pb.IdentifyResponse(
  185. id=PeerID.from_base58('QmbWqVVoz7v9LS9ZUQAhyyfdFJY3iU8ZrUY3XQozoTA5cc').to_bytes(),
  186. addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51552').to_bytes(),
  187. Multiaddr('/ip4/192.168.10.135/tcp/51552').to_bytes(),
  188. Multiaddr('/ip6/::1/tcp/51553').to_bytes()]
  189. )).SerializeToString(),
  190. ),
  191. # give test cases ids to prevent bytes from ruining the terminal
  192. ids=("pb example Response 0", "pb example Response 1", "pb example Response 2"),
  193. )
  194. @pytest.mark.asyncio
  195. async def test_read_pbmsg_safe_valid(msg_bytes):
  196. s = MockReaderWriter()
  197. await write_unsigned_varint(s, len(msg_bytes))
  198. s.write(msg_bytes)
  199. # reset the offset back to the beginning
  200. s.seek(0, 0)
  201. pb_msg = p2pd_pb.Response()
  202. await read_pbmsg_safe(s, pb_msg)
  203. assert pb_msg.SerializeToString() == msg_bytes
  204. @pytest.mark.parametrize(
  205. "pb_type, pb_msg",
  206. (
  207. (
  208. p2pd_pb.Response,
  209. p2pd_pb.Response(
  210. type=p2pd_pb.Response.Type.OK,
  211. dht=p2pd_pb.DHTResponse(
  212. type=p2pd_pb.DHTResponse.Type.VALUE,
  213. peer=p2pd_pb.PeerInfo(
  214. id=PeerID.from_base58('QmNaXUy78W9moQ9APCoKaTtPjLcEJPN9hRBCqErY7o2fQs').to_bytes(),
  215. addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56929').to_bytes(),
  216. Multiaddr('/ip4/192.168.10.135/tcp/56929').to_bytes(),
  217. Multiaddr('/ip6/::1/tcp/56930').to_bytes()]
  218. )
  219. )
  220. ),
  221. ),
  222. (p2pd_pb.Request, p2pd_pb.Request(type=p2pd_pb.Request.Type.LIST_PEERS)),
  223. (
  224. p2pd_pb.DHTRequest,
  225. p2pd_pb.DHTRequest(type=p2pd_pb.DHTRequest.Type.FIND_PEER,
  226. peer=PeerID.from_base58('QmcgHMuEhqdLHDVeNjiCGU7Ds6E7xK3f4amgiwHNPKKn7R').to_bytes()),
  227. ),
  228. (
  229. p2pd_pb.DHTResponse,
  230. p2pd_pb.DHTResponse(
  231. type=p2pd_pb.DHTResponse.Type.VALUE,
  232. peer=p2pd_pb.PeerInfo(
  233. id=PeerID.from_base58('QmWP32GhEyXVQsLXFvV81eadDC8zQRZxZvJK359rXxLquk').to_bytes(),
  234. addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56897').to_bytes(),
  235. Multiaddr('/ip4/192.168.10.135/tcp/56897').to_bytes(),
  236. Multiaddr('/ip6/::1/tcp/56898').to_bytes()]
  237. )
  238. ),
  239. ),
  240. (
  241. p2pd_pb.StreamInfo,
  242. p2pd_pb.StreamInfo(peer=PeerID.from_base58('QmewLxB46MftfxQiunRgJo2W8nW4Lh5NLEkRohkHhJ4wW6').to_bytes(),
  243. addr=Multiaddr('/ip4/127.0.0.1/tcp/57029').to_bytes(),
  244. proto=b'protocol123'),
  245. ),
  246. ),
  247. ids=(
  248. "pb example Response",
  249. "pb example Request",
  250. "pb example DHTRequest",
  251. "pb example DHTResponse",
  252. "pb example StreamInfo",
  253. ),
  254. )
  255. @pytest.mark.asyncio
  256. async def test_write_pbmsg(pb_type, pb_msg):
  257. msg_bytes = bytes(chr(pb_msg.ByteSize()), 'utf-8') + pb_msg.SerializeToString()
  258. pb_obj = pb_type()
  259. s_read = MockReaderWriter(msg_bytes)
  260. await read_pbmsg_safe(s_read, pb_obj)
  261. s_write = MockReaderWriter()
  262. await write_pbmsg(s_write, pb_obj)
  263. assert msg_bytes == s_write.getvalue()
  264. @pytest.mark.parametrize(
  265. "pb_msg",
  266. (
  267. p2pd_pb.Response(),
  268. p2pd_pb.Request(),
  269. p2pd_pb.DHTRequest(),
  270. p2pd_pb.DHTResponse(),
  271. p2pd_pb.StreamInfo(),
  272. ),
  273. )
  274. @pytest.mark.asyncio
  275. async def test_write_pbmsg_missing_fields(pb_msg):
  276. with pytest.raises(EncodeError):
  277. await write_pbmsg(MockReaderWriter(), pb_msg)
  278. @pytest.fixture
  279. async def p2pcs():
  280. # TODO: Change back to gather style
  281. async with AsyncExitStack() as stack:
  282. p2pd_tuples = [
  283. await stack.enter_async_context(
  284. FUNC_MAKE_P2PD_PAIR(
  285. enable_control=ENABLE_CONTROL,
  286. enable_connmgr=ENABLE_CONNMGR,
  287. enable_dht=ENABLE_DHT,
  288. enable_pubsub=ENABLE_PUBSUB,
  289. )
  290. )
  291. for _ in range(NUM_P2PDS)
  292. ]
  293. yield tuple(p2pd_tuple.client for p2pd_tuple in p2pd_tuples)
  294. @pytest.mark.asyncio
  295. async def test_client_identify_unix_socket(p2pcs):
  296. await p2pcs[0].identify()
  297. @pytest.mark.asyncio
  298. async def test_client_identify(p2pcs):
  299. await p2pcs[0].identify()
  300. @pytest.mark.asyncio
  301. async def test_client_connect_success(p2pcs):
  302. peer_id_0, maddrs_0 = await p2pcs[0].identify()
  303. peer_id_1, maddrs_1 = await p2pcs[1].identify()
  304. await p2pcs[0].connect(peer_id_1, maddrs_1)
  305. # test case: repeated connections
  306. await p2pcs[1].connect(peer_id_0, maddrs_0)
  307. @pytest.mark.asyncio
  308. async def test_client_connect_failure(p2pcs):
  309. peer_id_1, maddrs_1 = await p2pcs[1].identify()
  310. await p2pcs[0].identify()
  311. # test case: `peer_id` mismatches
  312. with pytest.raises(ControlFailure):
  313. await p2pcs[0].connect(PEER_ID_RANDOM, maddrs_1)
  314. # test case: empty maddrs
  315. with pytest.raises(ControlFailure):
  316. await p2pcs[0].connect(peer_id_1, [])
  317. # test case: wrong maddrs
  318. with pytest.raises(ControlFailure):
  319. await p2pcs[0].connect(peer_id_1, [Multiaddr("/ip4/127.0.0.1/udp/0")])
  320. @pytest.mark.asyncio
  321. async def test_connect_safe(p2pcs):
  322. await connect_safe(p2pcs[0], p2pcs[1])
  323. @pytest.mark.asyncio
  324. async def test_client_list_peers(p2pcs):
  325. # test case: no peers
  326. assert len(await p2pcs[0].list_peers()) == 0
  327. # test case: 1 peer
  328. await connect_safe(p2pcs[0], p2pcs[1])
  329. assert len(await p2pcs[0].list_peers()) == 1
  330. assert len(await p2pcs[1].list_peers()) == 1
  331. # test case: one more peer
  332. await connect_safe(p2pcs[0], p2pcs[2])
  333. assert len(await p2pcs[0].list_peers()) == 2
  334. assert len(await p2pcs[1].list_peers()) == 1
  335. assert len(await p2pcs[2].list_peers()) == 1
  336. @pytest.mark.asyncio
  337. async def test_client_disconnect(p2pcs):
  338. # test case: disconnect a peer without connections
  339. await p2pcs[1].disconnect(PEER_ID_RANDOM)
  340. # test case: disconnect
  341. peer_id_0, _ = await p2pcs[0].identify()
  342. await connect_safe(p2pcs[0], p2pcs[1])
  343. assert len(await p2pcs[0].list_peers()) == 1
  344. assert len(await p2pcs[1].list_peers()) == 1
  345. await p2pcs[1].disconnect(peer_id_0)
  346. assert len(await p2pcs[0].list_peers()) == 0
  347. assert len(await p2pcs[1].list_peers()) == 0
  348. # test case: disconnect twice
  349. await p2pcs[1].disconnect(peer_id_0)
  350. assert len(await p2pcs[0].list_peers()) == 0
  351. assert len(await p2pcs[1].list_peers()) == 0
  352. @pytest.mark.asyncio
  353. async def test_client_stream_open_success(p2pcs):
  354. peer_id_1, maddrs_1 = await p2pcs[1].identify()
  355. await connect_safe(p2pcs[0], p2pcs[1])
  356. proto = "123"
  357. async def handle_proto(stream_info, reader, writer):
  358. await reader.readexactly(1)
  359. await p2pcs[1].stream_handler(proto, handle_proto)
  360. # test case: normal
  361. stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
  362. assert stream_info.peer_id == peer_id_1
  363. assert stream_info.addr in maddrs_1
  364. assert stream_info.proto == "123"
  365. writer.close()
  366. # test case: open with multiple protocols
  367. stream_info, reader, writer = await p2pcs[0].stream_open(
  368. peer_id_1, (proto, "another_protocol")
  369. )
  370. assert stream_info.peer_id == peer_id_1
  371. assert stream_info.addr in maddrs_1
  372. assert stream_info.proto == "123"
  373. writer.close()
  374. @pytest.mark.asyncio
  375. async def test_client_stream_open_failure(p2pcs):
  376. peer_id_1, _ = await p2pcs[1].identify()
  377. await connect_safe(p2pcs[0], p2pcs[1])
  378. proto = "123"
  379. # test case: `stream_open` to a peer who didn't register the protocol
  380. with pytest.raises(ControlFailure):
  381. await p2pcs[0].stream_open(peer_id_1, (proto,))
  382. # test case: `stream_open` to a peer for a non-registered protocol
  383. async def handle_proto(stream_info, reader, writer):
  384. pass
  385. await p2pcs[1].stream_handler(proto, handle_proto)
  386. with pytest.raises(ControlFailure):
  387. await p2pcs[0].stream_open(peer_id_1, ("another_protocol",))
  388. @pytest.mark.asyncio
  389. async def test_client_stream_handler_success(p2pcs):
  390. peer_id_1, _ = await p2pcs[1].identify()
  391. await connect_safe(p2pcs[0], p2pcs[1])
  392. proto = "protocol123"
  393. bytes_to_send = b"yoyoyoyoyog"
  394. # event for this test function to wait until the handler function receiving the incoming data
  395. event_handler_finished = asyncio.Event()
  396. async def handle_proto(stream_info, reader, writer):
  397. nonlocal event_handler_finished
  398. bytes_received = await reader.readexactly(len(bytes_to_send))
  399. assert bytes_received == bytes_to_send
  400. event_handler_finished.set()
  401. await p2pcs[1].stream_handler(proto, handle_proto)
  402. assert proto in p2pcs[1].control.handlers
  403. assert handle_proto == p2pcs[1].control.handlers[proto]
  404. # test case: test the stream handler `handle_proto`
  405. _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
  406. # wait until the handler function starts blocking waiting for the data
  407. # because we haven't sent the data, we know the handler function must still blocking waiting.
  408. # get the task of the protocol handler
  409. writer.write(bytes_to_send)
  410. # wait for the handler to finish
  411. writer.close()
  412. await event_handler_finished.wait()
  413. # test case: two streams to different handlers respectively
  414. another_proto = "another_protocol123"
  415. another_bytes_to_send = b"456"
  416. event_another_proto = asyncio.Event()
  417. async def handle_another_proto(stream_info, reader, writer):
  418. event_another_proto.set()
  419. bytes_received = await reader.readexactly(len(another_bytes_to_send))
  420. assert bytes_received == another_bytes_to_send
  421. await p2pcs[1].stream_handler(another_proto, handle_another_proto)
  422. assert another_proto in p2pcs[1].control.handlers
  423. assert handle_another_proto == p2pcs[1].control.handlers[another_proto]
  424. _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (another_proto,))
  425. await event_another_proto.wait()
  426. # we know at this moment the handler must still blocking wait
  427. writer.write(another_bytes_to_send)
  428. writer.close()
  429. # test case: registering twice can override the previous registration
  430. event_third = asyncio.Event()
  431. async def handler_third(stream_info, reader, writer):
  432. event_third.set()
  433. await p2pcs[1].stream_handler(another_proto, handler_third)
  434. assert another_proto in p2pcs[1].control.handlers
  435. # ensure the handler is override
  436. assert handler_third == p2pcs[1].control.handlers[another_proto]
  437. await p2pcs[0].stream_open(peer_id_1, (another_proto,))
  438. # ensure the overriding handler is called when the protocol is opened a stream
  439. await event_third.wait()
  440. @pytest.mark.asyncio
  441. async def test_client_stream_handler_failure(p2pcs):
  442. peer_id_1, _ = await p2pcs[1].identify()
  443. await connect_safe(p2pcs[0], p2pcs[1])
  444. proto = "123"
  445. # test case: registered a wrong protocol name
  446. async def handle_proto_correct_params(stream_info, stream):
  447. pass
  448. await p2pcs[1].stream_handler("another_protocol", handle_proto_correct_params)
  449. with pytest.raises(ControlFailure):
  450. await p2pcs[0].stream_open(peer_id_1, (proto,))