
Make model uploading use access token from authorizer (#7)

Alexander Borzunov 3 years ago
parent
commit
09240991cc
2 changed files with 20 additions and 21 deletions
  1. run_aux_peer.py (+9 -11)
  2. task.py (+11 -10)
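
At its core, the commit swaps the token source for checkpoint uploads: instead of the token that huggingface-cli login caches on disk (read via HfFolder.get_token()), the Repository is now authenticated with the token held by the task's authorizer, and uploading refuses to start without one. A minimal sketch of the resulting pattern, assuming a task object that exposes authorizer.hf_user_access_token as in this commit (the helper name is hypothetical):

    from huggingface_hub import Repository

    def make_checkpoint_repo(local_dir: str, repo_url: str, task) -> Repository:
        # Uploading needs a WRITE-capable token, so fail fast if auth is disabled
        assert task.authorizer is not None, "Model uploading needs Hugging Face auth to be enabled"
        # Authenticate the local clone with the authorizer-provided token
        # rather than the token cached on disk by huggingface-cli login
        return Repository(
            local_dir=local_dir,
            clone_from=repo_url,
            use_auth_token=task.authorizer.hf_user_access_token,
        )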

run_aux_peer.py (+9 -11)

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-import threading
 import time
 
 import torch
@@ -27,14 +26,14 @@ class CheckpointHandler:
         self.local_path = peer_args.local_path
         self.upload_interval = peer_args.upload_interval
         if self.upload_interval is not None:
-            self.token = HfFolder.get_token()
+            assert task.authorizer is not None, 'Model uploading needs Hugging Face auth to be enabled'
             self.repo = Repository(
                 local_dir=self.local_path,
                 clone_from=peer_args.repo_url,
-                use_auth_token=self.token,
+                use_auth_token=task.authorizer.hf_user_access_token,
             )
+            self.last_upload_time = None
         self.previous_step = -1
-        self.previous_timestamp = time.time()
 
     def should_save_state(self, cur_step):
         if self.save_checkpoint_step_interval is None:
@@ -52,17 +51,18 @@ class CheckpointHandler:
     def is_time_to_upload(self):
         if self.upload_interval is None:
             return False
-        elif time.time() - self.previous_timestamp >= self.upload_interval:
+        elif self.last_upload_time is None or time.time() - self.last_upload_time >= self.upload_interval:
             return True
         else:
             return False
 
     def upload_checkpoint(self, current_loss):
+        self.last_upload_time = time.time()
+
         logger.info("Saving model")
         torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt")
         logger.info("Saving optimizer")
         torch.save(self.task.collaborative_optimizer.state_dict(), f"{self.local_path}/optimizer_state.pt")
-        self.previous_timestamp = time.time()
         logger.info("Started uploading to Model Hub")
         try:
             # We start by pulling the remote changes (for example a change in the readme file)
@@ -71,11 +71,9 @@ class CheckpointHandler:
             # Then we add / commit and push the changes
             self.repo.push_to_hub(commit_message=f"Epoch {self.task.collaborative_optimizer.local_epoch}, loss {current_loss:.3f}")
             logger.info("Finished uploading to Model Hub")
-        except OSError as e:
-            # There may be an error if a push arrives on the remote branch after the pull performed just above it. In
-            # this case the changes will be pushed with the next commit.
-            logger.error(f'The push to hub operation failed with error "{e}"')
-        
+        except Exception:
+            logger.exception("Uploading the checkpoint to HF Model Hub failed:")
+            logger.warning("Ensure that your access token is valid and has WRITE permissions")
 
 
 def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):
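
The hunks above also fix the upload cadence: the timestamp now starts as None, so the first upload fires immediately instead of waiting a full interval, and it is recorded at the top of upload_checkpoint, so a slow model/optimizer save no longer pushes back the next upload window. A self-contained sketch of just that timing logic (the class and method names here are hypothetical):

    import time
    from typing import Optional

    class UploadTimer:
        def __init__(self, upload_interval: Optional[float]):
            self.upload_interval = upload_interval
            self.last_upload_time = None  # None => no upload has happened yet

        def is_time_to_upload(self) -> bool:
            if self.upload_interval is None:
                return False  # uploading is disabled
            # Fire on the first check, then once per interval, measured
            # from when the previous upload *started* rather than finished
            return (self.last_upload_time is None
                    or time.time() - self.last_upload_time >= self.upload_interval)

        def mark_upload_started(self) -> None:
            self.last_upload_time = time.time()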

task.py (+11 -10)

@@ -5,13 +5,12 @@ from pathlib import Path
 
 import hivemind
 import torch
+import torch.nn as nn
 import transformers
 from dalle_pytorch import DALLE
 from dalle_pytorch.vae import VQGanVAE
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
 from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
-from torch import nn
-from transformers import training_args
 
 import utils
 from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
@@ -45,12 +44,14 @@ class ModelWrapper(nn.Module):
 
 class TrainingTask:
     """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""
-    _dht = _collaborative_optimizer = _training_dataset = None
+    _authorizer = _dht = _collaborative_optimizer = _training_dataset = None
 
 
     def __init__(
             self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments):
         self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
+        self.trainer_args.run_name = self.authorizer.username  # For wandb
+
         self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
         transformers.set_seed(trainer_args.seed)  # seed used for initialization
 
@@ -91,15 +92,15 @@ class TrainingTask:
             logger.info(f"Loading model from {latest_checkpoint_dir}")
             self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
 
+    @property
+    def authorizer(self):
+        if self._authorizer is None and self.peer_args.authorize:
+            self._authorizer = authorize_with_huggingface()
+        return self._authorizer
+
     @property
     def dht(self):
         if self._dht is None:
-            if self.peer_args.authorize:
-                authorizer = authorize_with_huggingface()
-                self.trainer_args.run_name = authorizer.username  # For wandb
-            else:
-                authorizer = None
-
             self._dht = hivemind.DHT(
                 start=True,
                 initial_peers=self.peer_args.initial_peers,
@@ -109,7 +110,7 @@ class TrainingTask:
                 use_ipfs=self.peer_args.use_ipfs,
                 record_validators=self.validators,
                 identity_path=self.peer_args.identity_path,
-                authorizer=authorizer,
+                authorizer=self.authorizer,
             )
             if self.peer_args.client_mode:
                 logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
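
Structurally, the authorizer moves out of the dht property and becomes a lazily constructed, cached property of TrainingTask, so the wandb run name set in __init__ and the DHT constructor share a single instance. A minimal sketch of the pattern, with authorize_with_huggingface standing in for the helper this repo already imports:

    class TrainingTask:
        _authorizer = None  # class-level default, replaced on first access

        def __init__(self, authorize: bool):
            self.authorize = authorize

        @property
        def authorizer(self):
            # Build the authorizer at most once, and only when auth is enabled;
            # returns None otherwise, so callers must be ready for that
            if self._authorizer is None and self.authorize:
                self._authorizer = authorize_with_huggingface()  # repo's auth helper
            return self._authorizer

One caveat visible in the diff: __init__ now dereferences self.authorizer.username unconditionally, so a peer started with authorization disabled would raise an AttributeError there; guarding the run_name assignment with a None check would keep the wandb naming optional.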