@@ -1,18 +1,17 @@
 import asyncio
 import multiprocessing as mp
 import subprocess
+from contextlib import closing
 from functools import partial
 from typing import List
 
 import numpy as np
 import pytest
-import torch
 from multiaddr import Multiaddr
 
 from hivemind.p2p import P2P, P2PHandlerError
-from hivemind.proto import dht_pb2, runtime_pb2
-from hivemind.utils import MSGPackSerializer
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.proto import dht_pb2
+from hivemind.utils.serializer import MSGPackSerializer
 
 
 def is_process_running(pid: int) -> bool:
@@ -23,13 +22,6 @@ async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
     return await P2P.replicate(p2p.daemon_listen_maddr) if replicate else p2p
 
 
-async def bootstrap_from(daemons: List[P2P]) -> List[Multiaddr]:
-    maddrs = []
-    for d in daemons:
-        maddrs += await d.get_visible_maddrs()
-    return maddrs
-
-
 @pytest.mark.asyncio
 async def test_daemon_killed_on_del():
     p2p_daemon = await P2P.create()
@@ -55,8 +47,7 @@ async def test_transports(host_maddrs: List[Multiaddr]):
     peers = await server.list_peers()
     assert len(peers) == 0
 
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=nodes)
+    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
     await client.wait_for_at_least_n_peers(1)
 
     peers = await client.list_peers()
@@ -80,48 +71,6 @@ async def test_daemon_replica_does_not_affect_primary():
     assert not is_process_running(child_pid)
 
 
-def handle_square(x):
-    x = MSGPackSerializer.loads(x)
-    return MSGPackSerializer.dumps(x ** 2)
-
-
-def handle_add(args):
-    args = MSGPackSerializer.loads(args)
-    result = args[0]
-    for i in range(1, len(args)):
-        result = result + args[i]
-    return MSGPackSerializer.dumps(result)
-
-
-def handle_square_torch(x):
-    tensor = runtime_pb2.Tensor()
-    tensor.ParseFromString(x)
-    tensor = deserialize_torch_tensor(tensor)
-    result = tensor ** 2
-    return serialize_torch_tensor(result).SerializeToString()
-
-
-def handle_add_torch(args):
-    args = MSGPackSerializer.loads(args)
-    tensor = runtime_pb2.Tensor()
-    tensor.ParseFromString(args[0])
-    result = deserialize_torch_tensor(tensor)
-
-    for i in range(1, len(args)):
-        tensor = runtime_pb2.Tensor()
-        tensor.ParseFromString(args[i])
-        result = result + deserialize_torch_tensor(tensor)
-
-    return serialize_torch_tensor(result).SerializeToString()
-
-
-def handle_add_torch_with_exc(args):
-    try:
-        return handle_add_torch(args)
-    except Exception:
-        return b"something went wrong :("
-
-
 @pytest.mark.parametrize(
     "should_cancel,replicate",
     [
@@ -146,11 +95,10 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
         return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
     server_pid = server_primary._child.pid
-    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
+    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest)
     assert is_process_running(server_pid)
 
-    nodes = await bootstrap_from([server])
-    client_primary = await P2P.create(initial_peers=nodes)
+    client_primary = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client = await replicate_if_needed(client_primary, replicate)
     client_pid = client_primary._child.pid
     assert is_process_running(client_pid)
@@ -160,10 +108,10 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
     expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
     if should_cancel:
-        stream_info, reader, writer = await client._client.stream_open(server.id, (handle_name,))
-        await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
+        *_, writer = await client.call_stream_handler(server.id, handle_name)
+        with closing(writer):
+            await P2P.send_protobuf(ping_request, writer)
 
-        writer.close()
         await asyncio.sleep(1)
         assert handler_cancelled
     else:
@@ -186,11 +134,10 @@ async def test_call_unary_handler_error(handle_name="handle"):
 
     server = await P2P.create()
     server_pid = server._child.pid
-    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
+    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest)
     assert is_process_running(server_pid)
 
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(initial_peers=nodes)
+    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client_pid = client._child.pid
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)
@@ -205,33 +152,47 @@ async def test_call_unary_handler_error(handle_name="handle"):
     await client.shutdown()
 
 
-@pytest.mark.parametrize(
-    "test_input,expected,handle",
-    [
-        pytest.param(10, 100, handle_square, id="square_integer"),
-        pytest.param((1, 2), 3, handle_add, id="add_integers"),
-        pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
-        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda"),
-    ],
-)
+async def handle_square_stream(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+    with closing(writer):
+        while True:
+            try:
+                x = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
+            except asyncio.IncompleteReadError:
+                break
+
+            result = x ** 2
+
+            await P2P.send_raw_data(MSGPackSerializer.dumps(result), writer)
+
+
+async def validate_square_stream(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+    with closing(writer):
+        for _ in range(10):
+            x = np.random.randint(100)
+
+            await P2P.send_raw_data(MSGPackSerializer.dumps(x), writer)
+            result = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
+
+            assert result == x ** 2
+
+
 @pytest.mark.asyncio
-async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
+async def test_call_peer_single_process():
     server = await P2P.create()
     server_pid = server._child.pid
-    await server.add_stream_handler(handler_name, handle)
     assert is_process_running(server_pid)
 
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(initial_peers=nodes)
+    handler_name = "square"
+    await server.add_stream_handler(handler_name, handle_square_stream)
+
+    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client_pid = client._child.pid
     assert is_process_running(client_pid)
 
     await client.wait_for_at_least_n_peers(1)
 
-    test_input_msgp = MSGPackSerializer.dumps(test_input)
-    result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
-    result = MSGPackSerializer.loads(result_msgp)
-    assert result == expected
+    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    await validate_square_stream(reader, writer)
 
     await server.shutdown()
     assert not is_process_running(server_pid)
@@ -240,12 +201,13 @@ async def test_call_peer_single_process(test_input, expected, handle, handler_na
     assert not is_process_running(client_pid)
 
 
-async def run_server(handler_name, server_side, client_side, response_received):
+async def run_server(handler_name, server_side, response_received):
     server = await P2P.create()
     server_pid = server._child.pid
-    await server.add_stream_handler(handler_name, handle_square)
     assert is_process_running(server_pid)
 
+    await server.add_stream_handler(handler_name, handle_square_stream)
+
     server_side.send(server.id)
     server_side.send(await server.get_visible_maddrs())
     while response_received.value == 0:
@@ -255,20 +217,19 @@ async def run_server(handler_name, server_side, client_side, response_received):
     assert not is_process_running(server_pid)
 
 
-def server_target(handler_name, server_side, client_side, response_received):
-    asyncio.run(run_server(handler_name, server_side, client_side, response_received))
+def server_target(handler_name, server_side, response_received):
+    asyncio.run(run_server(handler_name, server_side, response_received))
 
 
 @pytest.mark.asyncio
 async def test_call_peer_different_processes():
     handler_name = "square"
-    test_input = 2
 
     server_side, client_side = mp.Pipe()
     response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
     response_received.value = 0
 
-    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
+    proc = mp.Process(target=server_target, args=(handler_name, server_side, response_received))
     proc.start()
 
     peer_id = client_side.recv()
@@ -280,146 +241,95 @@ async def test_call_peer_different_processes():
 
     await client.wait_for_at_least_n_peers(1)
 
-    test_input_msgp = MSGPackSerializer.dumps(2)
-    result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
-    result = MSGPackSerializer.loads(result_msgp)
-    assert np.allclose(result, test_input ** 2)
+    _, reader, writer = await client.call_stream_handler(peer_id, handler_name)
+    await validate_square_stream(reader, writer)
+
     response_received.value = 1
 
     await client.shutdown()
     assert not is_process_running(client_pid)
 
     proc.join()
+    assert proc.exitcode == 0
 
 
-@pytest.mark.parametrize(
-    "test_input,expected",
-    [
-        pytest.param(torch.tensor([2]), torch.tensor(4)),
-        pytest.param(torch.tensor([[1.0, 2.0], [0.5, 0.1]]), torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
-    ],
-)
 @pytest.mark.asyncio
-async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
-    handle = handle_square_torch
-    server = await P2P.create()
-    await server.add_stream_handler(handler_name, handle)
-
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(initial_peers=nodes)
-
-    await client.wait_for_at_least_n_peers(1)
-
-    inp = serialize_torch_tensor(test_input).SerializeToString()
-    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
-    result = runtime_pb2.Tensor()
-    result.ParseFromString(result_pb)
-    result = deserialize_torch_tensor(result)
-    assert torch.allclose(result, expected)
+async def test_error_closes_connection():
+    async def handle_raising_error(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+        with closing(writer):
+            command = await P2P.receive_raw_data(reader)
+            if command == b"raise_error":
+                raise Exception("The handler has failed")
+            else:
+                await P2P.send_raw_data(b"okay", writer)
 
-    await server.shutdown()
-    await client.shutdown()
-
-
-@pytest.mark.parametrize(
-    "test_input,expected",
-    [
-        pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
-        pytest.param(
-            [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
-            torch.tensor([[1.2, 1.4], [1.6, 1.8]]),
-        ),
-    ],
-)
-@pytest.mark.asyncio
-async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
-    handle = handle_add_torch
     server = await P2P.create()
-    await server.add_stream_handler(handler_name, handle)
-
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(initial_peers=nodes)
-
-    await client.wait_for_at_least_n_peers(1)
+    server_pid = server._child.pid
+    assert is_process_running(server_pid)
 
-    inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
-    inp_msgp = MSGPackSerializer.dumps(inp)
-    result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
-    result = runtime_pb2.Tensor()
-    result.ParseFromString(result_pb)
-    result = deserialize_torch_tensor(result)
-    assert torch.allclose(result, expected)
-
-    await server.shutdown()
-    await client.shutdown()
+    handler_name = "handler"
+    await server.add_stream_handler(handler_name, handle_raising_error)
 
+    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
 
-@pytest.mark.parametrize(
-    "replicate",
-    [
-        pytest.param(False, id="primary"),
-        pytest.param(True, id="replica"),
-    ],
-)
-@pytest.mark.asyncio
-async def test_call_peer_error(replicate, handler_name="handle"):
-    server_primary = await P2P.create()
-    server = await replicate_if_needed(server_primary, replicate)
-    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)
+    await client.wait_for_at_least_n_peers(1)
 
-    nodes = await bootstrap_from([server])
-    client_primary = await P2P.create(initial_peers=nodes)
-    client = await replicate_if_needed(client_primary, replicate)
+    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    with closing(writer):
+        await P2P.send_raw_data(b"raise_error", writer)
+        with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
+            await P2P.receive_raw_data(reader)
 
-    await client.wait_for_at_least_n_peers(1)
+    # Even though the handler raised an exception, the server did not crash and is ready for the next requests
+    assert is_process_running(server_pid)
 
-    inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
-    inp_msgp = MSGPackSerializer.dumps(inp)
-    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
-    assert result == b"something went wrong :("
+    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    with closing(writer):
+        await P2P.send_raw_data(b"behave_normally", writer)
+        assert await P2P.receive_raw_data(reader) == b"okay"
 
-    await server_primary.shutdown()
     await server.shutdown()
-    await client_primary.shutdown()
+    assert not is_process_running(server_pid)
+
     await client.shutdown()
+    assert not is_process_running(client_pid)
 
 
 @pytest.mark.asyncio
-async def test_handlers_on_different_replicas(handler_name="handle"):
-    def handler(arg, key):
-        return key
+async def test_handlers_on_different_replicas():
+    async def handler(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, key: str) -> None:
+        with closing(writer):
+            await P2P.send_raw_data(key, writer)
 
     server_primary = await P2P.create()
     server_id = server_primary.id
-    await server_primary.add_stream_handler(handler_name, partial(handler, key=b"primary"))
+    await server_primary.add_stream_handler("handle_primary", partial(handler, key=b"primary"))
 
     server_replica1 = await replicate_if_needed(server_primary, True)
-    await server_replica1.add_stream_handler(handler_name + "1", partial(handler, key=b"replica1"))
+    await server_replica1.add_stream_handler("handle1", partial(handler, key=b"replica1"))
 
     server_replica2 = await replicate_if_needed(server_primary, True)
-    await server_replica2.add_stream_handler(handler_name + "2", partial(handler, key=b"replica2"))
+    await server_replica2.add_stream_handler("handle2", partial(handler, key=b"replica2"))
 
-    nodes = await bootstrap_from([server_primary])
-    client = await P2P.create(initial_peers=nodes)
+    client = await P2P.create(initial_peers=await server_primary.get_visible_maddrs())
     await client.wait_for_at_least_n_peers(1)
 
-    result = await client.call_peer_handler(server_id, handler_name, b"1")
-    assert result == b"primary"
-
-    result = await client.call_peer_handler(server_id, handler_name + "1", b"2")
-    assert result == b"replica1"
-
-    result = await client.call_peer_handler(server_id, handler_name + "2", b"3")
-    assert result == b"replica2"
+    for name, expected_key in [("handle_primary", b"primary"), ("handle1", b"replica1"), ("handle2", b"replica2")]:
+        _, reader, writer = await client.call_stream_handler(server_id, name)
+        with closing(writer):
+            assert await P2P.receive_raw_data(reader) == expected_key
 
     await server_replica1.shutdown()
     await server_replica2.shutdown()
 
-    # Primary does not handle replicas protocols
-    with pytest.raises(Exception):
-        await client.call_peer_handler(server_id, handler_name + "1", b"")
-    with pytest.raises(Exception):
-        await client.call_peer_handler(server_id, handler_name + "2", b"")
+    # Primary does not handle replicas' protocols after their shutdown
+
+    for name in ["handle1", "handle2"]:
+        _, reader, writer = await client.call_stream_handler(server_id, name)
+        with pytest.raises(asyncio.IncompleteReadError), closing(writer):
+            await P2P.receive_raw_data(reader)
 
     await server_primary.shutdown()
     await client.shutdown()