Browse Source

#204 P2P replica mode (#205)

* #204 P2P replica mode

* #204 rename replica->replicate
MaximKsh 4 years ago
parent
commit
2dbee5964c
2 changed files with 111 additions and 18 deletions
  1. 14 0
      hivemind/p2p/p2p_daemon.py
  2. 97 18
      tests/test_p2p_daemon.py

+ 14 - 0
hivemind/p2p/p2p_daemon.py

@@ -73,6 +73,20 @@ class P2P(object):
             break
         return self
 
+    @classmethod
+    async def replicate(cls, daemon_listen_port: int, host_port: int):
+        self = cls()
+        # This instance does not own a child daemon process;
+        # it attaches to an external, already running p2pd.
+        self._child = None
+        self._assign_daemon_ports(host_port, daemon_listen_port)
+        self._client_listen_port = find_open_port()
+        self._client = p2pclient.Client(
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
+        await self._identify_client(0)
+        return self
+
     def _initialize(self, proc_args: tp.List[str]) -> None:
         proc_args = copy.deepcopy(proc_args)
         proc_args.extend(self._make_process_args(

+ 97 - 18
tests/test_p2p_daemon.py

@@ -1,6 +1,7 @@
 import asyncio
 import multiprocessing as mp
 import subprocess
+from functools import partial
 
 from hivemind.p2p.p2p_daemon_bindings.datastructures import ID
 
@@ -27,6 +28,10 @@ def is_process_running(pid: int) -> bool:
     return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING
 
 
+async def replicate_if_needed(p2p: P2P, replicate: bool):
+    return await P2P.replicate(p2p._daemon_listen_port, p2p._host_port) if replicate else p2p
+
+
 @pytest.mark.asyncio
 async def test_daemon_killed_on_del():
     p2p_daemon = await P2P.create()
@@ -38,6 +43,21 @@ async def test_daemon_killed_on_del():
     assert not is_process_running(child_pid)
 
 
+@pytest.mark.asyncio
+async def test_daemon_replica_does_not_affect_primary():
+    p2p_daemon = await P2P.create()
+    p2p_replica = await P2P.replicate(p2p_daemon._daemon_listen_port, p2p_daemon._host_port)
+
+    child_pid = p2p_daemon._child.pid
+    assert is_process_running(child_pid)
+
+    p2p_replica.__del__()
+    assert is_process_running(child_pid)
+
+    p2p_daemon.__del__()
+    assert not is_process_running(child_pid)
+
+
 def handle_square(x):
     return x ** 2
 
@@ -50,10 +70,15 @@ def handle_add(args):
 
 
 @pytest.mark.parametrize(
-    'should_cancel', [True, False]
+    'should_cancel,replicate', [
+        (True, False),
+        (True, True),
+        (False, False),
+        (False, True),
+    ]
 )
 @pytest.mark.asyncio
-async def test_call_unary_handler(should_cancel, handle_name="handle"):
+async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
     handler_cancelled = False
 
     async def ping_handler(request, context):
@@ -67,14 +92,16 @@ async def test_call_unary_handler(should_cancel, handle_name="handle"):
                 node_id=context.ours_id.encode(), rpc_port=context.ours_port),
             sender_endpoint=context.handle_name, available=True)
 
-    server = await P2P.create()
-    server_pid = server._child.pid
+    server_primary = await P2P.create()
+    server = await replicate_if_needed(server_primary, replicate)
+    server_pid = server_primary._child.pid
     await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
                                    dht_pb2.PingResponse)
     assert is_process_running(server_pid)
 
-    client = await P2P.create()
-    client_pid = client._child.pid
+    client_primary = await P2P.create()
+    client = await replicate_if_needed(client_primary, replicate)
+    client_pid = client_primary._child.pid
     assert is_process_running(client_pid)
 
     ping_request = dht_pb2.PingRequest(
@@ -100,10 +127,10 @@ async def test_call_unary_handler(should_cancel, handle_name="handle"):
         assert not handler_cancelled
 
     await server.stop_listening()
-    server.__del__()
+    server_primary.__del__()
     assert not is_process_running(server_pid)
 
-    client.__del__()
+    client_primary.__del__()
     assert not is_process_running(client_pid)
 
 
@@ -131,7 +158,6 @@ async def test_call_peer_single_process(test_input, handle, handler_name="handle
     result = await client.call_peer_handler(server.id, handler_name, test_input)
     assert result == handle(test_input)
 
-    await server.stop_listening()
     server.__del__()
     assert not is_process_running(server_pid)
 
@@ -188,30 +214,83 @@ async def test_call_peer_different_processes():
 
 
 @pytest.mark.parametrize(
-    "test_input,handle",
+    "test_input,handle,replicate",
     [
-        pytest.param(np.random.randn(2, 3), handle_square, id="square"),
-        pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"),
+        pytest.param(np.random.randn(2, 3), handle_square, False, id="square_primary"),
+        pytest.param(np.random.randn(2, 3), handle_square, True, id="square_replica"),
+        pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, False, id="add_primary"),
+        pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, True, id="add_replica"),
     ]
 )
 @pytest.mark.asyncio
-async def test_call_peer_numpy(test_input, handle, handler_name="handle"):
-    server = await P2P.create()
+async def test_call_peer_numpy(test_input, handle, replicate, handler_name="handle"):
+    server_primary = await P2P.create()
+    server = await replicate_if_needed(server_primary, replicate)
     await server.add_stream_handler(handler_name, handle)
-    client = await P2P.create()
+    client_primary = await P2P.create()
+    client = await replicate_if_needed(client_primary, replicate)
 
     await asyncio.sleep(1)
     result = await client.call_peer_handler(server.id, handler_name, test_input)
     assert np.allclose(result, handle(test_input))
 
 
+@pytest.mark.parametrize(
+    "replicate",
+    [
+        pytest.param(False, id="primary"),
+        pytest.param(True, id="replica"),
+    ]
+)
 @pytest.mark.asyncio
-async def test_call_peer_error(handler_name="handle"):
-    server = await P2P.create()
+async def test_call_peer_error(replicate, handler_name="handle"):
+    server_primary = await P2P.create()
+    server = await replicate_if_needed(server_primary, replicate)
     await server.add_stream_handler(handler_name, handle_add)
-    client = await P2P.create()
+    client_primary = await P2P.create()
+    client = await replicate_if_needed(client_primary, replicate)
 
     await asyncio.sleep(1)
     result = await client.call_peer_handler(server.id, handler_name,
                                             [np.zeros((2, 3)), np.zeros((3, 2))])
     assert type(result) == ValueError
+
+
+@pytest.mark.asyncio
+async def test_handlers_on_different_replicas(handler_name="handle"):
+    def handler(arg, key):
+        return key
+
+    server_primary = await P2P.create()
+    server_id = server_primary.id
+    await server_primary.add_stream_handler(handler_name, partial(handler, key="primary"))
+
+    server_replica1 = await replicate_if_needed(server_primary, True)
+    await server_replica1.add_stream_handler(handler_name + "1", partial(handler, key="replica1"))
+
+    server_replica2 = await replicate_if_needed(server_primary, True)
+    await server_replica2.add_stream_handler(handler_name + "2", partial(handler, key="replica2"))
+
+    client = await P2P.create()
+    await asyncio.sleep(1)
+    result = await client.call_peer_handler(server_id, handler_name, "")
+    assert result == "primary"
+
+    result = await client.call_peer_handler(server_id, handler_name + "1", "")
+    assert result == "replica1"
+
+    result = await client.call_peer_handler(server_id, handler_name + "2", "")
+    assert result == "replica2"
+
+    await server_replica1.stop_listening()
+    await server_replica2.stop_listening()
+
+    # The primary does not handle the replicas' protocols
+    with pytest.raises(P2P.IncompleteRead):
+        await client.call_peer_handler(server_id, handler_name + "1", "")
+    with pytest.raises(P2P.IncompleteRead):
+        await client.call_peer_handler(server_id, handler_name + "2", "")
+
+    await server_primary.stop_listening()
+    server_primary.__del__()
+    client.__del__()