|
@@ -6,7 +6,7 @@ import torch.backends.quantized
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
import transformers
|
|
import transformers
|
|
from hivemind.utils.logging import get_logger
|
|
from hivemind.utils.logging import get_logger
|
|
-from huggingface_hub import Repository
|
|
|
|
|
|
+from huggingface_hub import HfApi, Repository
|
|
from tqdm.auto import tqdm
|
|
from tqdm.auto import tqdm
|
|
from transformers.models.bloom.modeling_bloom import BloomModel
|
|
from transformers.models.bloom.modeling_bloom import BloomModel
|
|
|
|
|
|
@@ -66,6 +66,8 @@ def main():
|
|
)
|
|
)
|
|
os.makedirs(args.output_path, exist_ok=True)
|
|
os.makedirs(args.output_path, exist_ok=True)
|
|
|
|
|
|
|
|
+ api = HfApi(token=args.use_auth_token)
|
|
|
|
+ api.create_repo(args.output_repo, repo_type="model", exist_ok=True)
|
|
repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
|
|
repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
|
|
repo.git_pull()
|
|
repo.git_pull()
|
|
|
|
|