
Add host_maddrs and announce_maddrs to CLI, fix minor review issues

Pavel Samygin 3 years ago
parent
commit
903271fd52

+ 11 - 14
benchmarks/benchmark_throughput.py

@@ -6,13 +6,13 @@ import time
 
 import torch
 
-import hivemind
-from hivemind import P2P
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpertWorker
-from hivemind.moe.server import layers
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
+from hivemind.moe.server import ExpertBackend, Server, layers
+from hivemind.p2p import P2P, PeerInfo
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -46,13 +46,11 @@ def client_process(
 
     p2p = RemoteExpertWorker.run_coroutine(P2P.create())
     RemoteExpertWorker.run_coroutine(p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
-    experts = [
-        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)
-    ]
+    experts = [RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
-        for batch_i in range(num_batches):
+        for _ in range(num_batches):
             expert = random.choice(experts)
             out = expert(dummy_batch)
             if backprop:
@@ -88,7 +86,7 @@ def benchmark_throughput(
 
     try:
         server_dht = DHT(start=True)
-        server_dht_peer_info = hivemind.PeerInfo(
+        server_dht_peer_info = PeerInfo(
             peer_id=server_dht.peer_id,
             addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
         )
@@ -121,17 +119,17 @@ def benchmark_throughput(
         experts = {}
         for i in range(num_experts):
             expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f"expert.{i}"] = hivemind.ExpertBackend(
+            experts[f"expert.{i}"] = ExpertBackend(
                 name=f"expert.{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
-                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
-                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
+                args_schema=(BatchTensorDescriptor(hid_dim),),
+                outputs_schema=BatchTensorDescriptor(hid_dim),
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
 
-        server = hivemind.moe.Server(
+        server = Server(
             dht=server_dht,
             expert_backends=experts,
             num_connection_handlers=num_handlers,
@@ -251,7 +249,6 @@ if __name__ == "__main__":
             num_clients=1,
             num_handlers=1,
             num_batches_per_client=args.num_batches_per_client,
-            batch_size=1024,
         )
     elif args.preset == "nop":
         benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)

+ 3 - 3
hivemind/dht/dht.py

@@ -55,7 +55,7 @@ class DHT(mp.Process):
         **kwargs,
     ):
         self._parent_pid = os.getpid()
-        self._my_pid = os.getpid()
+        self._origin_pid = os.getpid()
         super().__init__()
 
         if not (
@@ -311,8 +311,8 @@ class DHT(mp.Process):
         The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
         """
 
-        if self._p2p_replica is None or self._my_pid != os.getpid():
-            self._my_pid = os.getpid()
+        if self._p2p_replica is None or self._origin_pid != os.getpid():
+            self._origin_pid = os.getpid()
             daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
             self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
         return self._p2p_replica
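The rename clarifies the intent: `_origin_pid` records which process created the cached P2P replica, so a forked child creates its own replica instead of reusing the parent's. A generic sketch of this caching pattern (illustration only, not hivemind API):

    import os

    class PerProcessCache:
        """Re-create a cached resource whenever the current pid differs from the pid
        that created it (e.g. after a fork), mirroring replicate_p2p above."""

        def __init__(self, factory):
            self._factory = factory
            self._value = None
            self._origin_pid = os.getpid()

        def get(self):
            if self._value is None or self._origin_pid != os.getpid():
                self._origin_pid = os.getpid()
                self._value = self._factory()  # e.g. P2P.replicate(daemon_listen_maddr)
            return self._value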

+ 5 - 0
hivemind/hivemind_cli/run_server.py

@@ -31,6 +31,11 @@ def main():
                         help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
 
+    parser.add_argument('--host_maddrs', type=str, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
+                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
+    parser.add_argument('--announce_maddrs', type=str, nargs='+', default=None, required=False,
+                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
+
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--min_batch_size', type=int, default=1,
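With `nargs='+'`, argparse applies `type` to each token individually, which is why the flags use `type=str` (with `type=list`, every multiaddr would be split into single characters). A minimal sketch of how the new flags parse; the argv values are illustrative:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--host_maddrs', type=str, nargs='+', default=['/ip4/0.0.0.0/tcp/0'])
    parser.add_argument('--announce_maddrs', type=str, nargs='+', default=None)

    args = parser.parse_args(['--host_maddrs', '/ip4/0.0.0.0/tcp/31337',
                              '--announce_maddrs', '/ip4/203.0.113.5/tcp/31337'])
    assert args.host_maddrs == ['/ip4/0.0.0.0/tcp/31337']
    assert args.announce_maddrs == ['/ip4/203.0.113.5/tcp/31337']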

+ 0 - 1
hivemind/moe/__init__.py

@@ -1,6 +1,5 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import (
-    ConnectionHandler,
     ExpertBackend,
     Server,
     background_server,

+ 0 - 1
hivemind/moe/server/__init__.py

@@ -1,4 +1,3 @@
-from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.layers import register_expert_class
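With the re-export dropped from both hivemind.moe and hivemind.moe.server, code that still needs the class would presumably import it from its defining module (sketch; the module path is the one in the removed import line):

    from hivemind.moe.server.connection_handler import ConnectionHandler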

+ 3 - 1
hivemind/moe/server/server.py

@@ -109,6 +109,7 @@ class Server(threading.Thread):
         custom_module_path=None,
         *,
         start: bool,
+        **kwargs,
     ) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
@@ -140,12 +141,13 @@ class Server(threading.Thread):
 
         :param start: if True, starts server right away and returns when server is ready for requests
         :param stats_report_interval: interval between two reports of batch processing performance statistics
+        :param kwargs: any other params will be forwarded to DHT upon creation
         """
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
         assert expert_cls in name_to_block
 
-        dht = DHT(initial_peers=initial_peers, start=True)
+        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")