
Fix prompt tuning after #464 (#501)

Unfortunately, running inference in models with `"ptune" in config.tuning_mode` was broken after #464.
Alexander Borzunov · 1 year ago · commit d40eb6c701
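
For context, prompt tuning (`tuning_mode="ptune"`) prepends `config.pre_seq_len` trainable prompt tokens to every sequence, so the server-side attention cache must reserve room for them. A minimal reproduction sketch, assuming the usual Petals entry point `DistributedBloomForCausalLM` and the `tuning_mode`/`pre_seq_len` keyword arguments from the prompt-tuning tutorial (the model name and exact arguments are illustrative, not taken from this diff):

```python
import torch
from transformers import AutoTokenizer
from petals import DistributedBloomForCausalLM  # assumed import path for this sketch

MODEL_NAME = "bigscience/bloom-petals"  # placeholder model name

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME,
    tuning_mode="ptune",  # prompt tuning: config.pre_seq_len prompt tokens get prepended
    pre_seq_len=16,
)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# Before this fix, generate() reserved only inputs.shape[1] + max_new_tokens cache slots
# and ignored the 16 prepended prompt tokens, so the inference session could exceed its
# reserved max_length and fail.
with torch.inference_mode():
    outputs = model.generate(inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))
```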

+ 3 - 2
src/petals/client/remote_generation.py

@@ -87,10 +87,11 @@ class RemoteGenerationMixin(_SkipTokensMixin):
                 max_new_tokens is None
             ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
 
+            session_max_length = self.transformer.config.pre_seq_len
             if max_length is not None:
-                session_max_length = max_length
+                session_max_length += max_length
             else:
-                session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+                session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
             context_manager = self.inference_session(max_length=session_max_length)
 
         with context_manager as session:
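
With this change, the reserved session length always includes the prompt prefix on top of whatever budget the caller requested. A quick arithmetic sketch with hypothetical numbers (assuming `pre_seq_len` is 0 when prompt tuning is disabled):

```python
# Hypothetical values, only to illustrate the calculation above.
pre_seq_len = 16      # prompt-tuning prefix length (assumed 0 when ptune is off)
input_len = 5         # tokens in the user prompt
max_new_tokens = 8

# Old behaviour: the prefix was not counted, leaving the session 16 slots short.
old_session_max_length = input_len + max_new_tokens                 # 13

# New behaviour: the prefix is reserved up front.
new_session_max_length = pre_seq_len + input_len + max_new_tokens   # 29
```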

+ 3 - 2
src/petals/models/bloom/model.py

@@ -71,7 +71,8 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -88,7 +89,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
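
The bloom, falcon, and llama models all follow the same pattern: prepend `pre_seq_len` prompt embeddings before running the remote blocks and strip them from the returned hidden states afterwards, so both steps must be gated by the same condition. The falcon and llama hunks below apply the same refactor. A minimal standalone sketch of the pattern, using plain tensors instead of the real modules:

```python
import torch

def forward_with_prompts(inputs_embeds, prompts, use_prompts):
    """Toy version of the prepend/strip pattern used in the models above.

    inputs_embeds: (batch, seq_len, hidden); prompts: (batch, pre_seq_len, hidden).
    """
    pre_seq_len = prompts.shape[1]
    if use_prompts:
        # Prepend the trainable prompt embeddings, mirroring the use_prompts branch above.
        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

    hidden_states = inputs_embeds  # stand-in for the remote transformer blocks

    if use_prompts:
        # Strip the prefix under the same condition it was prepended; using a different
        # condition here would either cut real tokens or leave the prefix in the output.
        hidden_states = hidden_states[:, pre_seq_len:]
    return hidden_states

batch = torch.randn(2, 5, 8)
prompts = torch.randn(2, 3, 8)
assert forward_with_prompts(batch, prompts, use_prompts=True).shape == (2, 5, 8)
assert forward_with_prompts(batch, prompts, use_prompts=False).shape == (2, 5, 8)
```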

+ 3 - 2
src/petals/models/falcon/model.py

@@ -77,7 +77,8 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -94,7 +95,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state

+ 3 - 2
src/petals/models/llama/model.py

@@ -73,7 +73,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -90,7 +91,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state