5
0
justheuristic 3 жил өмнө
parent
commit
5695897620

+ 5 - 5
cli/convert_model.py

@@ -10,8 +10,8 @@ from huggingface_hub import Repository
 from tqdm.auto import tqdm
 
 from src import BloomModel
-from src.client.remote_model import DistributedBloomConfig
-
+from src.client import DistributedBloomConfig
+from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
@@ -26,9 +26,9 @@ if __name__ == "__main__":
     parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
     parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
     parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
-    parser.add_argument("--client_branch", type=str, default="client", help="Save client version to this branch")
+    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
     parser.add_argument(
-        "--block_branch_prefix", type=str, default="block_", help="Save blocks to branches with this prefix"
+        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
     )
     parser.add_argument(
         "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
@@ -50,7 +50,7 @@ if __name__ == "__main__":
     config = DistributedBloomConfig.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision
     )
-    config.dht_prefix = args.model
+    config.dht_prefix = args.output_repo
 
     model = BloomModel.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]

+ 3 - 4
src/server/server.py

@@ -6,14 +6,13 @@ from typing import Dict, Optional, Sequence, Union
 
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
-from hivemind.moe.server.dht_handler import DHTHandlerThread
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import declare_active_modules
-from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
+from src import declare_active_modules, BloomConfig
+from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
@@ -140,7 +139,7 @@ class Server(threading.Thread):
             assert num_blocks is not None
             block_indices = range(num_blocks)  # TODO replace with proper load balancing
 
-        block_config = DistributedBloomConfig.from_pretrained(
+        block_config = BloomConfig.from_pretrained(
             converted_model_name_or_path, use_auth_token=use_auth_token
         )