Jelajahi Sumber

Infer prefix by defaukt

justheuristic 3 tahun lalu
induk
melakukan
9c492bbe8c
3 mengubah file dengan 12 tambahan dan 3 penghapusan
  1. 2 1
      cli/run_server.py
  2. 2 1
      src/client/remote_model.py
  3. 8 1
      src/server/server.py

+ 2 - 1
cli/run_server.py

@@ -14,11 +14,12 @@ def main():
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
 
-    parser.add_argument('--prefix', type=str, required=True, help="Announce all blocks with this prefix")
     parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
                         help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
+    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
+                                                                 "use the same name as in the converted model.")
     parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
                         help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
     parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,

+ 2 - 1
src/client/remote_model.py

@@ -26,7 +26,8 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
-        assert 'initial_peers' in kwargs
+        if 'initial_peers' not in kwargs:
+            raise ValueError("Please specify initial_peers=...")
         dht = hivemind.DHT(
             initial_peers=kwargs.pop('initial_peers'), client_mode=kwargs.pop('client_mode', True),
             start=True)

+ 8 - 1
src/server/server.py

@@ -14,6 +14,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src import declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
+from src.data_structures import UID_DELIMITER, CHAIN_DELIMITER
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
@@ -84,7 +85,7 @@ class Server(threading.Thread):
     @classmethod
     def create(
         cls,
-        prefix: str,
+        prefix: Optional[str],
         converted_model_name_or_path: str,
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
@@ -108,6 +109,12 @@ class Server(threading.Thread):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
+        if prefix is None:
+            prefix = converted_model_name_or_path
+            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix,\
+                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); " \
+                f"Please specify --prefix manually when starting a server"
+            logger.info(f"Automatic dht prefix: {prefix}")
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]