浏览代码

Improve P2P handler throughput and interface (#316)

This PR implements the following:

1. Implement best practices for sending large data in `P2P.send_raw_data`:

    - No concatenation and copying of large byte strings
    - Sending data by chunks (64 KiB by default)
    - Awaiting `writer.drain()`

    Together with @justheuristic, we have discovered that these changes significantly improve throughput and latency of sending large amounts of data.

2. Change stream handler interface to allow reading and writing concurrently (basically, reverting them to the original stream handler interface of p2p-daemon-bindings).

3. Make meaningful tests for stream handlers:

    - Remove tests simulating msgpack-based unary handlers (the unary handlers are already tested + raw stream handlers are tested inside the tests for p2p-daemon-bindings)
    - Add tests for streaming functionality
    - Add tests for a stream handler raising an exception

4. Minor refactorings:

    - Remove unnecessary type argument for `P2P.send_protobuf()`
    - Remove unused `P2P.receive_msgpack()` and `P2P.send_msgpack()`
    - Use `contextlib.closing()` decorator
    - Rename `handle` to `handler` for consistency
Alexander Borzunov 4 年之前
父节点
当前提交
4f4b3abd2d
共有 3 个文件被更改,包括 128 次插入、258 次删除
  1. +26 −64
      hivemind/p2p/p2p_daemon.py
  2. +2 −4
      hivemind/p2p/servicer.py
  3. +100 −190
      tests/test_p2p_daemon.py

+ 26 - 64
hivemind/p2p/p2p_daemon.py

@@ -1,11 +1,11 @@
 import asyncio
 import os
 import secrets
-from contextlib import suppress
+from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
 from subprocess import Popen
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 
 import google.protobuf
 from multiaddr import Multiaddr
@@ -14,7 +14,6 @@ import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto import p2pd_pb2
-from hivemind.utils import MSGPackSerializer
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -247,20 +246,16 @@ class P2P:
         return self._daemon_listen_maddr
 
     @staticmethod
-    async def send_raw_data(data: bytes, writer: asyncio.StreamWriter) -> None:
-        request = len(data).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + data
-        writer.write(request)
+    async def send_raw_data(data: bytes, writer: asyncio.StreamWriter, *, chunk_size: int = 2 ** 16) -> None:
+        writer.write(len(data).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER))
+        data = memoryview(data)
+        for offset in range(0, len(data), chunk_size):
+            writer.write(data[offset : offset + chunk_size])
+        await writer.drain()
 
     @staticmethod
-    async def send_msgpack(data: Any, writer: asyncio.StreamWriter) -> None:
-        raw_data = MSGPackSerializer.dumps(data)
-        await P2P.send_raw_data(raw_data, writer)
-
-    @staticmethod
-    async def send_protobuf(protobuf, out_proto_type: type, writer: asyncio.StreamWriter) -> None:
-        if type(protobuf) != out_proto_type:
-            raise TypeError("Unary handler returned protobuf of wrong type.")
-        if out_proto_type == p2pd_pb2.RPCError:
+    async def send_protobuf(protobuf, writer: asyncio.StreamWriter) -> None:
+        if isinstance(protobuf, p2pd_pb2.RPCError):
             await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
         else:
             await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer)
@@ -274,10 +269,6 @@ class P2P:
         data = await reader.readexactly(content_length)
         return data
 
-    @staticmethod
-    async def receive_msgpack(reader: asyncio.StreamReader) -> Any:
-        return MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
-
     @staticmethod
     async def receive_protobuf(
         in_proto_type: type, reader: asyncio.StreamReader
@@ -294,28 +285,7 @@ class P2P:
         else:
             raise TypeError("Invalid Protobuf message type")
 
-    @staticmethod
-    def _handle_stream(handle: Callable[[bytes], bytes]):
-        async def do_handle_stream(
-            stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
-        ):
-            try:
-                request = await P2P.receive_raw_data(reader)
-            except asyncio.IncompleteReadError:
-                logger.debug("Incomplete read while receiving request from peer")
-                writer.close()
-                return
-            try:
-                result = handle(request)
-                await P2P.send_raw_data(result, writer)
-            finally:
-                writer.close()
-
-        return do_handle_stream
-
-    def _handle_unary_stream(
-        self, handle: Callable[[Any, P2PContext], Any], handle_name: str, in_proto_type: type, out_proto_type: type
-    ):
+    def _handle_unary_stream(self, handler: Callable[[Any, P2PContext], Any], handle_name: str, in_proto_type: type):
         async def watchdog(reader: asyncio.StreamReader) -> None:
             await reader.read(n=1)
             raise P2PInterruptedError()
@@ -323,7 +293,7 @@ class P2P:
         async def do_handle_unary_stream(
             stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
         ) -> None:
-            try:
+            with closing(writer):
                 try:
                     request, err = await P2P.receive_protobuf(in_proto_type, reader)
                 except asyncio.IncompleteReadError:
@@ -344,23 +314,21 @@ class P2P:
                     remote_maddr=stream_info.addr,
                 )
                 done, pending = await asyncio.wait(
-                    [watchdog(reader), handle(request, context)], return_when=asyncio.FIRST_COMPLETED
+                    [watchdog(reader), handler(request, context)], return_when=asyncio.FIRST_COMPLETED
                 )
                 try:
                     result = done.pop().result()
-                    await P2P.send_protobuf(result, out_proto_type, writer)
+                    await P2P.send_protobuf(result, writer)
                 except P2PInterruptedError:
                     pass
                 except Exception as exc:
                     error = p2pd_pb2.RPCError(message=str(exc))
-                    await P2P.send_protobuf(error, p2pd_pb2.RPCError, writer)
+                    await P2P.send_protobuf(error, writer)
                 finally:
                     if pending:
                         for task in pending:
                             task.cancel()
                         await asyncio.wait(pending)
-            finally:
-                writer.close()
 
         return do_handle_unary_stream
 
@@ -381,39 +349,33 @@ class P2P:
                 self._listen_task = None
                 self._server_stopped.clear()
 
-    async def add_stream_handler(self, name: str, handle: Callable[[bytes], bytes]) -> None:
+    async def add_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
         if self._listen_task is None:
             self._start_listening()
-        await self._client.stream_handler(name, self._handle_stream(handle))
+        await self._client.stream_handler(name, handler)
 
     async def add_unary_handler(
-        self, name: str, handle: Callable[[Any, P2PContext], Any], in_proto_type: type, out_proto_type: type
+        self, name: str, handler: Callable[[Any, P2PContext], Any], in_proto_type: type
     ) -> None:
         if self._listen_task is None:
             self._start_listening()
-        await self._client.stream_handler(name, self._handle_unary_stream(handle, name, in_proto_type, out_proto_type))
+        await self._client.stream_handler(name, self._handle_unary_stream(handler, name, in_proto_type))
 
-    async def call_peer_handler(self, peer_id: PeerID, handler_name: str, input_data: bytes) -> bytes:
-        stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
-        try:
-            await P2P.send_raw_data(input_data, writer)
-            return await P2P.receive_raw_data(reader)
-        finally:
-            writer.close()
+    async def call_stream_handler(
+        self, peer_id: PeerID, handler_name: str
+    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
+        return await self._client.stream_open(peer_id, (handler_name,))
 
     async def call_unary_handler(
         self, peer_id: PeerID, handler_name: str, request_protobuf: Any, response_proto_type: type
     ) -> Any:
-        stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
-        try:
-            await P2P.send_protobuf(request_protobuf, type(request_protobuf), writer)
+        _, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
+        with closing(writer):
+            await P2P.send_protobuf(request_protobuf, writer)
             result, err = await P2P.receive_protobuf(response_proto_type, reader)
             if err is not None:
                 raise P2PHandlerError(f"Failed to call unary handler {handler_name} at {peer_id}: {err.message}")
-
             return result
-        finally:
-            writer.close()
 
     def __del__(self):
         self._terminate()

+ 2 - 4
hivemind/p2p/servicer.py

@@ -1,10 +1,9 @@
 import asyncio
 import importlib
 from dataclasses import dataclass
-from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
-from hivemind.p2p.p2p_daemon import P2P, P2PContext
+from hivemind.p2p.p2p_daemon import P2P
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 
 
@@ -89,7 +88,6 @@ class Servicer:
                 handler.handle_name,
                 getattr(servicer, handler.method_name),
                 handler.request_type,
-                handler.response_type,
             )
 
     def get_stub(self, p2p: P2P, peer: PeerID) -> StubBase:

+ 100 - 190
tests/test_p2p_daemon.py

@@ -1,18 +1,17 @@
 import asyncio
 import multiprocessing as mp
 import subprocess
+from contextlib import closing
 from functools import partial
 from typing import List
 
 import numpy as np
 import pytest
-import torch
 from multiaddr import Multiaddr
 
 from hivemind.p2p import P2P, P2PHandlerError
-from hivemind.proto import dht_pb2, runtime_pb2
-from hivemind.utils import MSGPackSerializer
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.proto import dht_pb2
+from hivemind.utils.serializer import MSGPackSerializer
 
 
 def is_process_running(pid: int) -> bool:
@@ -23,13 +22,6 @@ async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
     return await P2P.replicate(p2p.daemon_listen_maddr) if replicate else p2p
 
 
-async def bootstrap_from(daemons: List[P2P]) -> List[Multiaddr]:
-    maddrs = []
-    for d in daemons:
-        maddrs += await d.get_visible_maddrs()
-    return maddrs
-
-
 @pytest.mark.asyncio
 async def test_daemon_killed_on_del():
     p2p_daemon = await P2P.create()
@@ -55,8 +47,7 @@ async def test_transports(host_maddrs: List[Multiaddr]):
     peers = await server.list_peers()
     assert len(peers) == 0
 
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=nodes)
+    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
     await client.wait_for_at_least_n_peers(1)
 
     peers = await client.list_peers()
@@ -80,48 +71,6 @@ async def test_daemon_replica_does_not_affect_primary():
     assert not is_process_running(child_pid)
 
 
-def handle_square(x):
-    x = MSGPackSerializer.loads(x)
-    return MSGPackSerializer.dumps(x ** 2)
-
-
-def handle_add(args):
-    args = MSGPackSerializer.loads(args)
-    result = args[0]
-    for i in range(1, len(args)):
-        result = result + args[i]
-    return MSGPackSerializer.dumps(result)
-
-
-def handle_square_torch(x):
-    tensor = runtime_pb2.Tensor()
-    tensor.ParseFromString(x)
-    tensor = deserialize_torch_tensor(tensor)
-    result = tensor ** 2
-    return serialize_torch_tensor(result).SerializeToString()
-
-
-def handle_add_torch(args):
-    args = MSGPackSerializer.loads(args)
-    tensor = runtime_pb2.Tensor()
-    tensor.ParseFromString(args[0])
-    result = deserialize_torch_tensor(tensor)
-
-    for i in range(1, len(args)):
-        tensor = runtime_pb2.Tensor()
-        tensor.ParseFromString(args[i])
-        result = result + deserialize_torch_tensor(tensor)
-
-    return serialize_torch_tensor(result).SerializeToString()
-
-
-def handle_add_torch_with_exc(args):
-    try:
-        return handle_add_torch(args)
-    except Exception:
-        return b"something went wrong :("
-
-
 @pytest.mark.parametrize(
     "should_cancel,replicate",
     [
@@ -146,11 +95,10 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
         return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
     server_pid = server_primary._child.pid
-    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
+    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest)
     assert is_process_running(server_pid)
 
-    nodes = await bootstrap_from([server])
-    client_primary = await P2P.create(initial_peers=nodes)
+    client_primary = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client = await replicate_if_needed(client_primary, replicate)
     client_pid = client_primary._child.pid
     assert is_process_running(client_pid)
@@ -160,10 +108,10 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
     expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
     if should_cancel:
-        stream_info, reader, writer = await client._client.stream_open(server.id, (handle_name,))
-        await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
+        *_, writer = await client.call_stream_handler(server.id, handle_name)
+        with closing(writer):
+            await P2P.send_protobuf(ping_request, writer)
 
-        writer.close()
         await asyncio.sleep(1)
         assert handler_cancelled
     else:
@@ -186,11 +134,10 @@ async def test_call_unary_handler_error(handle_name="handle"):
 
     server = await P2P.create()
     server_pid = server._child.pid
-    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
+    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest)
     assert is_process_running(server_pid)
 
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(initial_peers=nodes)
+    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client_pid = client._child.pid
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)
@@ -205,33 +152,47 @@ async def test_call_unary_handler_error(handle_name="handle"):
     await client.shutdown()
 
 
-@pytest.mark.parametrize(
-    "test_input,expected,handle",
-    [
-        pytest.param(10, 100, handle_square, id="square_integer"),
-        pytest.param((1, 2), 3, handle_add, id="add_integers"),
-        pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
-        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda"),
-    ],
-)
+async def handle_square_stream(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+    with closing(writer):
+        while True:
+            try:
+                x = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
+            except asyncio.IncompleteReadError:
+                break
+
+            result = x ** 2
+
+            await P2P.send_raw_data(MSGPackSerializer.dumps(result), writer)
+
+
+async def validate_square_stream(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+    with closing(writer):
+        for _ in range(10):
+            x = np.random.randint(100)
+
+            await P2P.send_raw_data(MSGPackSerializer.dumps(x), writer)
+            result = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
+
+            assert result == x ** 2
+
+
 @pytest.mark.asyncio
-async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
+async def test_call_peer_single_process():
     server = await P2P.create()
     server_pid = server._child.pid
-    await server.add_stream_handler(handler_name, handle)
     assert is_process_running(server_pid)
 
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(initial_peers=nodes)
+    handler_name = "square"
+    await server.add_stream_handler(handler_name, handle_square_stream)
+
+    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client_pid = client._child.pid
     assert is_process_running(client_pid)
 
     await client.wait_for_at_least_n_peers(1)
 
-    test_input_msgp = MSGPackSerializer.dumps(test_input)
-    result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
-    result = MSGPackSerializer.loads(result_msgp)
-    assert result == expected
+    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    await validate_square_stream(reader, writer)
 
     await server.shutdown()
     assert not is_process_running(server_pid)
@@ -240,12 +201,13 @@ async def test_call_peer_single_process(test_input, expected, handle, handler_na
     assert not is_process_running(client_pid)
 
 
-async def run_server(handler_name, server_side, client_side, response_received):
+async def run_server(handler_name, server_side, response_received):
     server = await P2P.create()
     server_pid = server._child.pid
-    await server.add_stream_handler(handler_name, handle_square)
     assert is_process_running(server_pid)
 
+    await server.add_stream_handler(handler_name, handle_square_stream)
+
     server_side.send(server.id)
     server_side.send(await server.get_visible_maddrs())
     while response_received.value == 0:
@@ -255,20 +217,19 @@ async def run_server(handler_name, server_side, client_side, response_received):
     assert not is_process_running(server_pid)
 
 
-def server_target(handler_name, server_side, client_side, response_received):
-    asyncio.run(run_server(handler_name, server_side, client_side, response_received))
+def server_target(handler_name, server_side, response_received):
+    asyncio.run(run_server(handler_name, server_side, response_received))
 
 
 @pytest.mark.asyncio
 async def test_call_peer_different_processes():
     handler_name = "square"
-    test_input = 2
 
     server_side, client_side = mp.Pipe()
     response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
     response_received.value = 0
 
-    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
+    proc = mp.Process(target=server_target, args=(handler_name, server_side, response_received))
     proc.start()
 
     peer_id = client_side.recv()
@@ -280,146 +241,95 @@ async def test_call_peer_different_processes():
 
     await client.wait_for_at_least_n_peers(1)
 
-    test_input_msgp = MSGPackSerializer.dumps(2)
-    result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
-    result = MSGPackSerializer.loads(result_msgp)
-    assert np.allclose(result, test_input ** 2)
+    _, reader, writer = await client.call_stream_handler(peer_id, handler_name)
+    await validate_square_stream(reader, writer)
+
     response_received.value = 1
 
     await client.shutdown()
     assert not is_process_running(client_pid)
 
     proc.join()
+    assert proc.exitcode == 0
 
 
-@pytest.mark.parametrize(
-    "test_input,expected",
-    [
-        pytest.param(torch.tensor([2]), torch.tensor(4)),
-        pytest.param(torch.tensor([[1.0, 2.0], [0.5, 0.1]]), torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
-    ],
-)
 @pytest.mark.asyncio
-async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
-    handle = handle_square_torch
-    server = await P2P.create()
-    await server.add_stream_handler(handler_name, handle)
-
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(initial_peers=nodes)
-
-    await client.wait_for_at_least_n_peers(1)
-
-    inp = serialize_torch_tensor(test_input).SerializeToString()
-    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
-    result = runtime_pb2.Tensor()
-    result.ParseFromString(result_pb)
-    result = deserialize_torch_tensor(result)
-    assert torch.allclose(result, expected)
+async def test_error_closes_connection():
+    async def handle_raising_error(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+        with closing(writer):
+            command = await P2P.receive_raw_data(reader)
+            if command == b"raise_error":
+                raise Exception("The handler has failed")
+            else:
+                await P2P.send_raw_data(b"okay", writer)
 
-    await server.shutdown()
-    await client.shutdown()
-
-
-@pytest.mark.parametrize(
-    "test_input,expected",
-    [
-        pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
-        pytest.param(
-            [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
-            torch.tensor([[1.2, 1.4], [1.6, 1.8]]),
-        ),
-    ],
-)
-@pytest.mark.asyncio
-async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
-    handle = handle_add_torch
     server = await P2P.create()
-    await server.add_stream_handler(handler_name, handle)
-
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(initial_peers=nodes)
-
-    await client.wait_for_at_least_n_peers(1)
+    server_pid = server._child.pid
+    assert is_process_running(server_pid)
 
-    inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
-    inp_msgp = MSGPackSerializer.dumps(inp)
-    result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
-    result = runtime_pb2.Tensor()
-    result.ParseFromString(result_pb)
-    result = deserialize_torch_tensor(result)
-    assert torch.allclose(result, expected)
-
-    await server.shutdown()
-    await client.shutdown()
+    handler_name = "handler"
+    await server.add_stream_handler(handler_name, handle_raising_error)
 
+    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
 
-@pytest.mark.parametrize(
-    "replicate",
-    [
-        pytest.param(False, id="primary"),
-        pytest.param(True, id="replica"),
-    ],
-)
-@pytest.mark.asyncio
-async def test_call_peer_error(replicate, handler_name="handle"):
-    server_primary = await P2P.create()
-    server = await replicate_if_needed(server_primary, replicate)
-    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)
+    await client.wait_for_at_least_n_peers(1)
 
-    nodes = await bootstrap_from([server])
-    client_primary = await P2P.create(initial_peers=nodes)
-    client = await replicate_if_needed(client_primary, replicate)
+    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    with closing(writer):
+        await P2P.send_raw_data(b"raise_error", writer)
+        with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
+            await P2P.receive_raw_data(reader)
 
-    await client.wait_for_at_least_n_peers(1)
+    # Despite the handler raised an exception, the server did not crash and ready for next requests
+    assert is_process_running(server_pid)
 
-    inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
-    inp_msgp = MSGPackSerializer.dumps(inp)
-    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
-    assert result == b"something went wrong :("
+    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    with closing(writer):
+        await P2P.send_raw_data(b"behave_normally", writer)
+        assert await P2P.receive_raw_data(reader) == b"okay"
 
-    await server_primary.shutdown()
     await server.shutdown()
-    await client_primary.shutdown()
+    assert not is_process_running(server_pid)
+
     await client.shutdown()
+    assert not is_process_running(client_pid)
 
 
 @pytest.mark.asyncio
-async def test_handlers_on_different_replicas(handler_name="handle"):
-    def handler(arg, key):
-        return key
+async def test_handlers_on_different_replicas():
+    async def handler(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, key: str) -> None:
+        with closing(writer):
+            await P2P.send_raw_data(key, writer)
 
     server_primary = await P2P.create()
     server_id = server_primary.id
-    await server_primary.add_stream_handler(handler_name, partial(handler, key=b"primary"))
+    await server_primary.add_stream_handler("handle_primary", partial(handler, key=b"primary"))
 
     server_replica1 = await replicate_if_needed(server_primary, True)
-    await server_replica1.add_stream_handler(handler_name + "1", partial(handler, key=b"replica1"))
+    await server_replica1.add_stream_handler("handle1", partial(handler, key=b"replica1"))
 
     server_replica2 = await replicate_if_needed(server_primary, True)
-    await server_replica2.add_stream_handler(handler_name + "2", partial(handler, key=b"replica2"))
+    await server_replica2.add_stream_handler("handle2", partial(handler, key=b"replica2"))
 
-    nodes = await bootstrap_from([server_primary])
-    client = await P2P.create(initial_peers=nodes)
+    client = await P2P.create(initial_peers=await server_primary.get_visible_maddrs())
     await client.wait_for_at_least_n_peers(1)
 
-    result = await client.call_peer_handler(server_id, handler_name, b"1")
-    assert result == b"primary"
-
-    result = await client.call_peer_handler(server_id, handler_name + "1", b"2")
-    assert result == b"replica1"
-
-    result = await client.call_peer_handler(server_id, handler_name + "2", b"3")
-    assert result == b"replica2"
+    for name, expected_key in [("handle_primary", b"primary"), ("handle1", b"replica1"), ("handle2", b"replica2")]:
+        _, reader, writer = await client.call_stream_handler(server_id, name)
+        with closing(writer):
+            assert await P2P.receive_raw_data(reader) == expected_key
 
     await server_replica1.shutdown()
     await server_replica2.shutdown()
 
-    # Primary does not handle replicas protocols
-    with pytest.raises(Exception):
-        await client.call_peer_handler(server_id, handler_name + "1", b"")
-    with pytest.raises(Exception):
-        await client.call_peer_handler(server_id, handler_name + "2", b"")
+    # Primary does not handle replicas protocols after their shutdown
+
+    for name in ["handle1", "handle2"]:
+        _, reader, writer = await client.call_stream_handler(server_id, name)
+        with pytest.raises(asyncio.IncompleteReadError), closing(writer):
+            await P2P.receive_raw_data(reader)
 
     await server_primary.shutdown()
     await client.shutdown()