Explorar el Código

set default DHT prefix

justheuristic hace 3 años
padre
commit
90d65e58aa
Se ha modificado 1 fichero con 7 adiciones y 2 borrados
  1. 7 2
      cli/convert_model.py

+ 7 - 2
cli/convert_model.py

@@ -9,6 +9,9 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from huggingface_hub import Repository
 from tqdm.auto import tqdm
 
+from src import BloomModel
+from src.client.remote_model import DistributedBloomConfig
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
@@ -44,10 +47,12 @@ if __name__ == "__main__":
         raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
 
     logger.info(f"Loading source model {args.model} (this may take a few minutes)")
-    config = transformers.AutoConfig.from_pretrained(
+    config = DistributedBloomConfig.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision
     )
-    model = transformers.AutoModel.from_pretrained(
+    config.dht_prefix = args.model
+
+    model = BloomModel.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
     tokenizer = transformers.AutoTokenizer.from_pretrained(