artek0chumak 2 years ago
parent
commit
1387711236

+ 6 - 0
src/petals/client/inference_session.py

@@ -103,6 +103,9 @@ class _ServerInferenceSession:
         else:
             assert len(hypo_ids) == len(new_hidden_states)
             assert hypo_ids.dtype == torch.int64
+            
+        if attention_mask is None:
+            attention_mask = DUMMY
 
         if attention_mask is None:
             attention_mask = DUMMY
@@ -232,6 +235,9 @@ class InferenceSession:
             prompts = DUMMY
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+            
+        if attention_mask is None:
+            attention_mask = DUMMY
 
         if attention_mask is None:
             attention_mask = DUMMY
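
Both hunks in inference_session.py apply the same pattern: if the caller does not pass an attention mask, substitute the DUMMY placeholder so the request still carries a tensor for that field. Below is a minimal, illustrative sketch of that pattern; DUMMY and is_dummy are stand-ins for Petals' placeholder helpers, and step() is a simplified mock, not the library's method.

import torch

# Illustrative stand-ins: an empty tensor marks "no value provided".
DUMMY = torch.empty(0)

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

def step(new_hidden_states, attention_mask=None, prompts=None):
    # Mirrors the added lines: when the caller omits the mask,
    # fall back to the DUMMY placeholder, just like prompts.
    if attention_mask is None:
        attention_mask = DUMMY
    if prompts is None:
        prompts = DUMMY
    return new_hidden_states, attention_mask, prompts

hidden, mask, prompts = step(torch.zeros(1, 1, 16))
print(is_dummy(mask), is_dummy(prompts))  # True True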

+ 1 - 2
src/petals/client/remote_generation.py

@@ -179,9 +179,8 @@ class RemoteGenerationMixin:
                     hidden_state = torch.cat([prompts, hidden_state], dim=1)
                 hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
 
-                attention_mask = torch.ones((batch_size, seq_idx), device=hidden_state.device)
                 hidden_state = session.step(
-                    hidden_state, attention_mask, prompts=intermediate_prompts, hypo_ids=hypo_ids
+                    hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids
                 )[:, -1]
 
                 hidden_state = self.transformer.ln_f(hidden_state)
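
The remote_generation.py change drops the per-step construction of an explicit mask; session.step is now called without one and the session-level DUMMY default applies. A short sketch of the reasoning, under the assumption that this code path never masks anything out: an all-ones mask is the "attend to every position" default, so omitting it is behaviorally equivalent.

import torch

# The removed line built this mask on every generation step:
#     attention_mask = torch.ones((batch_size, seq_idx), device=hidden_state.device)
# All ones means no position is masked, so passing nothing and letting
# the session substitute its dummy placeholder gives the same result here.
batch_size, seq_idx = 2, 5
explicit_mask = torch.ones((batch_size, seq_idx))
assert bool(explicit_mask.all())  # no position is masked out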

+ 1 - 1
src/petals/client/sequential_autograd.py

@@ -302,7 +302,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
             )
         )
         grad_input_batches = [output[0][0] for output in outputs]
-        grad_prompt_batches = [output[2] for output in outputs]
+        grad_prompt_batches = [output[1] for output in outputs]
 
         grad_inputs = torch.cat(grad_input_batches, dim=0)
         dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
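
The sequential_autograd.py fix changes which element of each per-batch backward output is read as the prompt gradient (output[1] instead of output[2]). A minimal sketch of the corrected indexing; the tuple layout here, (list_of_grad_input_tensors, grad_prompts_or_None), is an assumption for illustration rather than the library's exact return type.

import torch

# Assumed per-batch output layout: (list_of_grad_input_tensors, grad_prompts_or_None).
# Under that layout output[1] is the prompt gradient; output[2] would point
# past the tuple or at the wrong field, which is what the fix corrects.
outputs = [
    ([torch.randn(2, 4)], torch.randn(1, 2, 4)),
    ([torch.randn(2, 4)], None),  # None stands for a dummy prompt gradient
]

grad_input_batches = [output[0][0] for output in outputs]
grad_prompt_batches = [output[1] for output in outputs]  # was output[2]

grad_inputs = torch.cat(grad_input_batches, dim=0)
dummy_grad_prompts = [g is None for g in grad_prompt_batches]
print(grad_inputs.shape, dummy_grad_prompts)  # torch.Size([4, 4]) [False, True]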

+ 3 - 0
src/petals/server/handler.py

@@ -158,6 +158,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                                 f" exceeds pre-allocated maximum {max_length}"
                             )
+                            
+                        if is_dummy(attention_mask):
+                            attention_mask = torch.ones((hidden_states.shape[0], prefix_length + length_increment), dtype=hypo_ids.dtype)
 
                         if is_dummy(attention_mask):
                             attention_mask = torch.ones(
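
On the server side, handler.py expands a dummy mask from the client into an all-ones mask covering both the cached prefix and the newly pushed tokens. A minimal sketch of that reconstruction, with shapes and the hypo_ids dtype taken from the diff; materialize_mask is an illustrative helper, not the handler's actual code, and the numel() check stands in for is_dummy().

import torch

def materialize_mask(attention_mask, hidden_states, prefix_length, length_increment, hypo_ids):
    # A dummy placeholder from the client means "attend to everything":
    # build a mask of ones over prefix + newly added positions.
    if attention_mask.numel() == 0:  # stands in for is_dummy(attention_mask)
        attention_mask = torch.ones(
            (hidden_states.shape[0], prefix_length + length_increment),
            dtype=hypo_ids.dtype,
        )
    return attention_mask

# Example: batch of 2, 3 cached tokens, 1 new token -> mask of shape (2, 4)
mask = materialize_mask(
    torch.empty(0), torch.zeros(2, 1, 16), prefix_length=3, length_increment=1,
    hypo_ids=torch.zeros(2, dtype=torch.int64),
)
print(mask.shape)  # torch.Size([2, 4])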