Browse source

Fix DHT listening address, allow starting a server with no experts (#121)

* Fix DHT listening address

* Allow starting the server with 0 experts

* Change NONE to CompressionType.NONE
Max Ryabinin 4 years ago
commit
82c3e51131
3 changed files with 11 additions and 14 deletions
  1. hivemind/__init__.py (+1 −1)
  2. hivemind/server/__init__.py (+7 −6)
  3. scripts/run_server.py (+3 −7)

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.12'
+__version__ = '0.8.13'

+ 7 - 6
hivemind/server/__init__.py

@@ -106,13 +106,13 @@ class Server(threading.Thread):
         dht = None
         if not no_dht:
             logger.info(f"Bootstrapping DHT node, initial peers = {initial_peers}")
-            dht = hivemind.DHT(initial_peers=initial_peers, start=True,
-                               listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}")
+            dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
+            dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
             if verbose:
                 logger.info(f"Running dht node on port {dht.port}")
 
         # get expert uids
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None), \
+        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
             "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
         if expert_uids is None:
             assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
@@ -162,9 +162,10 @@ class Server(threading.Thread):
             if not self.dht.is_alive():
                 self.dht.run_in_background(await_ready=True)
 
-            dht_handler_thread = DHTHandlerThread(
-                experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
-            dht_handler_thread.start()
+            if self.experts:
+                dht_handler_thread = DHTHandlerThread(
+                    experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
+                dht_handler_thread.start()
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
 

+ 3 - 7
scripts/run_server.py

@@ -1,14 +1,11 @@
 from functools import partial
 
 import configargparse
-import resource
-
 import torch
 
+from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.server import Server
 from hivemind.utils.threading import increase_file_limit
-from hivemind.proto.runtime_pb2 import CompressionType
-
 
 
 def main():
@@ -16,7 +13,7 @@ def main():
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
     parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
-                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6")
+                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
     parser.add_argument('--num_experts', type=int, default=None, required=False, help="The number of experts to serve")
     parser.add_argument('--expert_pattern', type=str, default=None, required=False,
                         help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will sample random expert uids'
@@ -60,13 +57,12 @@ def main():
         increase_file_limit()
 
     compression_name = args.pop("compression")
-    compression = CompressionType.NONE
     if compression_name == "MEANSTD":
         compression = CompressionType.MEANSTD_LAST_AXIS_FLOAT16
     elif compression_name == "FLOAT16":
         compression = CompressionType.FLOAT16
     else:
-        compression = getattr(CompressionType, compression_name)
+        compression = CompressionType.NONE
 
     try:
         server = Server.create(**args, optim_cls=optim_cls, start=True, verbose=True, compression=compression)