artek0chumak committed 2 years ago · commit 9b2f4f6791

+ 2 - 2
src/petals/bloom/block.py

@@ -28,9 +28,9 @@ class WrappedBloomBlock(BloomBlock):
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        causal_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
         return super().forward(
-            hidden_states, *args, attention_mask=causal_mask, alibi=alibi, layer_past=layer_past, **kwargs
+            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
 
     def _prepare_attn_mask(
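
Note on the block.py change: the mask returned by _prepare_attn_mask now overwrites the attention_mask variable directly instead of being held under a separate causal_mask name before it is passed to BloomBlock.forward; the value that reaches the parent class is the same combined causal-plus-padding mask either way. Below is a minimal sketch of what such a mask-preparation step computes, assuming the usual Hugging Face BLOOM convention that True marks positions which must not be attended to; it is an illustration, not Petals' exact implementation.

import torch

def sketch_prepare_attn_mask(attention_mask: torch.Tensor, seq_length: int, past_length: int) -> torch.Tensor:
    # attention_mask: [batch, past_length + seq_length], 1 for real tokens, 0 for padding
    total_length = past_length + seq_length
    query_pos = torch.arange(seq_length)[:, None] + past_length   # absolute position of each query token
    key_pos = torch.arange(total_length)[None, :]
    causal = key_pos > query_pos                                   # True = future key, must be masked
    padding = ~attention_mask.bool()[:, None, :]                   # True = padding token, must be masked
    return causal[None] | padding                                  # [batch, seq_length, total_length]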

+ 3 - 3
src/petals/client/remote_model.py

@@ -189,11 +189,11 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, input_shape[-1]), device=hidden_states.device)
-
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
+        
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)
 

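Note on the remote_model.py change: the default attention mask is now built after the learned prompts have been prepended to inputs_embeds, and its length comes from hidden_states.size(1) rather than input_shape[-1], so it also covers the prompt positions. A toy sketch of the shape difference, using made-up sizes (batch 2, 5 input tokens, 3 prompt tokens) rather than Petals' actual configuration:

import torch

batch_size, input_len, pre_seq_len, hidden_size = 2, 5, 3, 8
inputs_embeds = torch.randn(batch_size, input_len, hidden_size)
prompts = torch.randn(batch_size, pre_seq_len, hidden_size)        # learned soft prompts
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)         # [2, 8, hidden_size]

old_default = torch.ones(batch_size, input_len)                    # [2, 5] -- misses the prompt tokens
new_default = torch.ones(batch_size, inputs_embeds.size(1))        # [2, 8] -- covers prompts + inputs
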
+ 1 - 1
src/petals/client/sequential_autograd.py

@@ -302,7 +302,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
             )
         )
         grad_input_batches = [output[0][0] for output in outputs]
-        grad_prompt_batches = [output[1] for output in outputs]
+        grad_prompt_batches = [output[2] for output in outputs]
 
         grad_inputs = torch.cat(grad_input_batches, dim=0)
         dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
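
Note on the sequential_autograd.py change: the per-batch prompt gradients are now read from position 2 of each backward result instead of position 1, while the input gradients still come from output[0][0]; presumably the layout of that result tuple changed elsewhere and the index here follows it. A minimal sketch of the surrounding pattern, with the tuple layout treated as an assumption rather than Petals' exact return signature:

import torch
from typing import Any, List, Optional, Sequence, Tuple

def collect_backward_grads(
    outputs: Sequence[Tuple[Any, ...]],
) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], List[bool]]:
    grad_input_batches = [output[0][0] for output in outputs]      # gradients w.r.t. the block inputs
    grad_prompt_batches = [output[2] for output in outputs]        # gradients w.r.t. the prepended prompts
    grad_inputs = torch.cat(grad_input_batches, dim=0)             # re-join the per-batch splits
    dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
    return grad_inputs, grad_prompt_batches, dummy_grad_prompts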