dbaranchuk 3 years ago
parent
commit
5b06dc2255
1 changed file with 7 additions and 7 deletions

+ 7 - 7
src/client/remote_model.py

@@ -34,7 +34,7 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
-    num_prefix_tokens: int = 0  # a number of tokens for prompt tuning.
+    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
 
 
 class DistributedBloomModel(BloomModel):
@@ -112,11 +112,11 @@ class DistributedBloomPrefix(DistributedBloomModel):
 
     def __init__(self, config):
         super().__init__(config)
-        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
-        self.prefix_length = config.num_prefix_tokens
+        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+        self.pre_seq_len = config.pre_seq_len
 
-        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.prefix_length).long()
+        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
 
     def get_prompt(self, batch_size):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
@@ -163,7 +163,7 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
 
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)
@@ -223,7 +223,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
 
     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)
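
The renamed `pre_seq_len` field controls the same prompt-tuning mechanism as before: when it is positive, the model prepends that many learned "virtual token" embeddings to the input sequence. Below is a minimal, self-contained sketch of that pattern in plain PyTorch. It is not the library's API; the class and variable names here are illustrative, and only the `pre_seq_len` / `prompt_embeddings` / `prefix_tokens` logic mirrors the diff above.

```python
import torch
import torch.nn as nn


class PrefixPromptSketch(nn.Module):
    """Illustrative prompt-tuning prefix, mirroring the logic in the diff."""

    def __init__(self, pre_seq_len: int, hidden_size: int):
        super().__init__()
        assert pre_seq_len > 0, "The number of prefix tokens must be > 0"
        self.pre_seq_len = pre_seq_len
        # One trainable vector per virtual prompt token.
        self.prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size)
        # Fixed indices 0..pre_seq_len-1 used to look up the prompt vectors.
        self.prefix_tokens = torch.arange(pre_seq_len).long()

    def get_prompt(self, batch_size: int) -> torch.Tensor:
        # Repeat the same prompt indices for every sequence in the batch.
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        return self.prompt_embeddings(prefix_tokens)  # (batch, pre_seq_len, hidden)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # Prepend the learned prompt embeddings to the regular token embeddings.
        prompts = self.get_prompt(inputs_embeds.shape[0])
        return torch.cat([prompts, inputs_embeds], dim=1)


if __name__ == "__main__":
    sketch = PrefixPromptSketch(pre_seq_len=16, hidden_size=64)
    dummy_embeds = torch.randn(2, 10, 64)  # (batch, seq_len, hidden)
    print(sketch(dummy_embeds).shape)      # torch.Size([2, 26, 64])
```

As in the `DistributedBloomForCausalLM` and `DistributedBloomForSequenceClassification` constructors above, a caller would branch on `config.pre_seq_len > 0` to decide whether the prefix variant is instantiated at all.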