فهرست منبع

Add identity_path option for MoE.Server runners (#484)

add identity_path option for moe.server runners
Pavel Samygin 3 سال پیش
والد
کامیت
6c56a8732a
2فایلهای تغییر یافته به همراه84 افزوده شده و 0 حذف شده
  1. 1 0
      hivemind/hivemind_cli/run_server.py
  2. 83 0
      tests/test_start_server.py

+ 1 - 0
hivemind/hivemind_cli/run_server.py

@@ -69,6 +69,7 @@ def main():
 
     parser.add_argument('--custom_module_path', type=str, required=False,
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
+    parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
 
     # fmt:on
     args = vars(parser.parse_args())

+ 83 - 0
tests/test_start_server.py

@@ -0,0 +1,83 @@
+import os
+import re
+from subprocess import PIPE, Popen
+from tempfile import TemporaryDirectory
+
+from hivemind.moe.server import background_server
+
+
+def test_background_server_identity_path():
+    with TemporaryDirectory() as tempdir:
+        id_path = os.path.join(tempdir, "id")
+
+        with background_server(num_experts=1, identity_path=id_path) as server_info_1, background_server(
+            num_experts=1, identity_path=id_path
+        ) as server_info_2, background_server(num_experts=1, identity_path=None) as server_info_3:
+
+            assert server_info_1.peer_id == server_info_2.peer_id
+            assert server_info_1.peer_id != server_info_3.peer_id
+            assert server_info_3.peer_id == server_info_3.peer_id
+
+
+def test_cli_run_server_identity_path():
+    pattern = r"Running DHT node on \[(.+)\],"
+
+    with TemporaryDirectory() as tempdir:
+        id_path = os.path.join(tempdir, "id")
+
+        server_1_proc = Popen(
+            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
+            stderr=PIPE,
+            text=True,
+            encoding="utf-8",
+        )
+
+        # Skip line "Generating new identity (libp2p private key) in {path to file}"
+        line = server_1_proc.stderr.readline()
+        line = server_1_proc.stderr.readline()
+        addrs_1 = set(re.search(pattern, line).group(1).split(", "))
+        ids_1 = set(a.split("/")[-1] for a in addrs_1)
+
+        assert len(ids_1) == 1
+
+        server_2_proc = Popen(
+            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
+            stderr=PIPE,
+            text=True,
+            encoding="utf-8",
+        )
+
+        line = server_2_proc.stderr.readline()
+        addrs_2 = set(re.search(pattern, line).group(1).split(", "))
+        ids_2 = set(a.split("/")[-1] for a in addrs_2)
+
+        assert len(ids_2) == 1
+
+        server_3_proc = Popen(
+            ["hivemind-server", "--num_experts", "1"],
+            stderr=PIPE,
+            text=True,
+            encoding="utf-8",
+        )
+
+        line = server_3_proc.stderr.readline()
+        addrs_3 = set(re.search(pattern, line).group(1).split(", "))
+        ids_3 = set(a.split("/")[-1] for a in addrs_3)
+
+        assert len(ids_3) == 1
+
+        assert ids_1 == ids_2
+        assert ids_1 != ids_3
+        assert ids_2 != ids_3
+
+        assert addrs_1 != addrs_2
+        assert addrs_1 != addrs_3
+        assert addrs_2 != addrs_3
+
+        server_1_proc.terminate()
+        server_2_proc.terminate()
+        server_3_proc.terminate()
+
+        server_1_proc.wait()
+        server_2_proc.wait()
+        server_3_proc.wait()