|
@@ -1,6 +1,7 @@
|
|
|
import asyncio
|
|
|
import multiprocessing as mp
|
|
|
import subprocess
|
|
|
+from functools import partial
|
|
|
|
|
|
from hivemind.p2p.p2p_daemon_bindings.datastructures import ID
|
|
|
|
|
@@ -27,6 +28,10 @@ def is_process_running(pid: int) -> bool:
|
|
|
return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING
|
|
|
|
|
|
|
|
|
+async def replicate_if_needed(p2p: P2P, replicate: bool):
|
|
|
+ return await P2P.replicate(p2p._daemon_listen_port, p2p._host_port) if replicate else p2p
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.asyncio
|
|
|
async def test_daemon_killed_on_del():
|
|
|
p2p_daemon = await P2P.create()
|
|
@@ -38,6 +43,21 @@ async def test_daemon_killed_on_del():
|
|
|
assert not is_process_running(child_pid)
|
|
|
|
|
|
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_daemon_replica_does_not_affect_primary():
|
|
|
+ p2p_daemon = await P2P.create()
|
|
|
+ p2p_replica = await P2P.replicate(p2p_daemon._daemon_listen_port, p2p_daemon._host_port)
|
|
|
+
|
|
|
+ child_pid = p2p_daemon._child.pid
|
|
|
+ assert is_process_running(child_pid)
|
|
|
+
|
|
|
+ p2p_replica.__del__()
|
|
|
+ assert is_process_running(child_pid)
|
|
|
+
|
|
|
+ p2p_daemon.__del__()
|
|
|
+ assert not is_process_running(child_pid)
|
|
|
+
|
|
|
+
|
|
|
def handle_square(x):
|
|
|
return x ** 2
|
|
|
|
|
@@ -50,10 +70,15 @@ def handle_add(args):
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
- 'should_cancel', [True, False]
|
|
|
+ 'should_cancel,replicate', [
|
|
|
+ (True, False),
|
|
|
+ (True, True),
|
|
|
+ (False, False),
|
|
|
+ (False, True),
|
|
|
+ ]
|
|
|
)
|
|
|
@pytest.mark.asyncio
|
|
|
-async def test_call_unary_handler(should_cancel, handle_name="handle"):
|
|
|
+async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
|
|
|
handler_cancelled = False
|
|
|
|
|
|
async def ping_handler(request, context):
|
|
@@ -67,14 +92,16 @@ async def test_call_unary_handler(should_cancel, handle_name="handle"):
|
|
|
node_id=context.ours_id.encode(), rpc_port=context.ours_port),
|
|
|
sender_endpoint=context.handle_name, available=True)
|
|
|
|
|
|
- server = await P2P.create()
|
|
|
- server_pid = server._child.pid
|
|
|
+ server_primary = await P2P.create()
|
|
|
+ server = await replicate_if_needed(server_primary, replicate)
|
|
|
+ server_pid = server_primary._child.pid
|
|
|
await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
|
|
|
dht_pb2.PingResponse)
|
|
|
assert is_process_running(server_pid)
|
|
|
|
|
|
- client = await P2P.create()
|
|
|
- client_pid = client._child.pid
|
|
|
+ client_primary = await P2P.create()
|
|
|
+ client = await replicate_if_needed(client_primary, replicate)
|
|
|
+ client_pid = client_primary._child.pid
|
|
|
assert is_process_running(client_pid)
|
|
|
|
|
|
ping_request = dht_pb2.PingRequest(
|
|
@@ -100,10 +127,10 @@ async def test_call_unary_handler(should_cancel, handle_name="handle"):
|
|
|
assert not handler_cancelled
|
|
|
|
|
|
await server.stop_listening()
|
|
|
- server.__del__()
|
|
|
+ server_primary.__del__()
|
|
|
assert not is_process_running(server_pid)
|
|
|
|
|
|
- client.__del__()
|
|
|
+ client_primary.__del__()
|
|
|
assert not is_process_running(client_pid)
|
|
|
|
|
|
|
|
@@ -131,7 +158,6 @@ async def test_call_peer_single_process(test_input, handle, handler_name="handle
|
|
|
result = await client.call_peer_handler(server.id, handler_name, test_input)
|
|
|
assert result == handle(test_input)
|
|
|
|
|
|
- await server.stop_listening()
|
|
|
server.__del__()
|
|
|
assert not is_process_running(server_pid)
|
|
|
|
|
@@ -188,30 +214,83 @@ async def test_call_peer_different_processes():
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
- "test_input,handle",
|
|
|
+ "test_input,handle,replicate",
|
|
|
[
|
|
|
- pytest.param(np.random.randn(2, 3), handle_square, id="square"),
|
|
|
- pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"),
|
|
|
+ pytest.param(np.random.randn(2, 3), handle_square, False, id="square_primary"),
|
|
|
+ pytest.param(np.random.randn(2, 3), handle_square, True, id="square_replica"),
|
|
|
+ pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, False, id="add_primary"),
|
|
|
+ pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, True, id="add_replica"),
|
|
|
]
|
|
|
)
|
|
|
@pytest.mark.asyncio
|
|
|
-async def test_call_peer_numpy(test_input, handle, handler_name="handle"):
|
|
|
- server = await P2P.create()
|
|
|
+async def test_call_peer_numpy(test_input, handle, replicate, handler_name="handle"):
|
|
|
+ server_primary = await P2P.create()
|
|
|
+ server = await replicate_if_needed(server_primary, replicate)
|
|
|
await server.add_stream_handler(handler_name, handle)
|
|
|
- client = await P2P.create()
|
|
|
+ client_primary = await P2P.create()
|
|
|
+ client = await replicate_if_needed(client_primary, replicate)
|
|
|
|
|
|
await asyncio.sleep(1)
|
|
|
result = await client.call_peer_handler(server.id, handler_name, test_input)
|
|
|
assert np.allclose(result, handle(test_input))
|
|
|
|
|
|
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "replicate",
|
|
|
+ [
|
|
|
+ pytest.param(False, id="primary"),
|
|
|
+ pytest.param(True, id="replica"),
|
|
|
+ ]
|
|
|
+)
|
|
|
@pytest.mark.asyncio
|
|
|
-async def test_call_peer_error(handler_name="handle"):
|
|
|
- server = await P2P.create()
|
|
|
+async def test_call_peer_error(replicate, handler_name="handle"):
|
|
|
+ server_primary = await P2P.create()
|
|
|
+ server = await replicate_if_needed(server_primary, replicate)
|
|
|
await server.add_stream_handler(handler_name, handle_add)
|
|
|
- client = await P2P.create()
|
|
|
+ client_primary = await P2P.create()
|
|
|
+ client = await replicate_if_needed(client_primary, replicate)
|
|
|
|
|
|
await asyncio.sleep(1)
|
|
|
result = await client.call_peer_handler(server.id, handler_name,
|
|
|
[np.zeros((2, 3)), np.zeros((3, 2))])
|
|
|
assert type(result) == ValueError
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_handlers_on_different_replicas(handler_name="handle"):
|
|
|
+ def handler(arg, key):
|
|
|
+ return key
|
|
|
+
|
|
|
+ server_primary = await P2P.create()
|
|
|
+ server_id = server_primary.id
|
|
|
+ await server_primary.add_stream_handler(handler_name, partial(handler, key="primary"))
|
|
|
+
|
|
|
+ server_replica1 = await replicate_if_needed(server_primary, True)
|
|
|
+ await server_replica1.add_stream_handler(handler_name + "1", partial(handler, key="replica1"))
|
|
|
+
|
|
|
+ server_replica2 = await replicate_if_needed(server_primary, True)
|
|
|
+ await server_replica2.add_stream_handler(handler_name + "2", partial(handler, key="replica2"))
|
|
|
+
|
|
|
+ client = await P2P.create()
|
|
|
+ await asyncio.sleep(1)
|
|
|
+ result = await client.call_peer_handler(server_id, handler_name, "")
|
|
|
+ assert result == "primary"
|
|
|
+
|
|
|
+ result = await client.call_peer_handler(server_id, handler_name + "1", "")
|
|
|
+ assert result == "replica1"
|
|
|
+
|
|
|
+ result = await client.call_peer_handler(server_id, handler_name + "2", "")
|
|
|
+ assert result == "replica2"
|
|
|
+
|
|
|
+ await server_replica1.stop_listening()
|
|
|
+ await server_replica2.stop_listening()
|
|
|
+
|
|
|
+ # Primary does not handle replicas protocols
|
|
|
+ with pytest.raises(P2P.IncompleteRead):
|
|
|
+ await client.call_peer_handler(server_id, handler_name + "1", "")
|
|
|
+ with pytest.raises(P2P.IncompleteRead):
|
|
|
+ await client.call_peer_handler(server_id, handler_name + "2", "")
|
|
|
+
|
|
|
+ await server_primary.stop_listening()
|
|
|
+ server_primary.__del__()
|
|
|
+ client.__del__()
|