
Provide attention mask

Artem Chumachenko 2 years ago
parent commit 4e478782a4

+ 3 - 5
src/petals/bloom/block.py

@@ -18,8 +18,8 @@ class WrappedBloomBlock(BloomBlock):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        *args,
         attention_mask: Optional[torch.Tensor] = None,
+        *args,
         alibi: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs
@@ -27,13 +27,11 @@ class WrappedBloomBlock(BloomBlock):
         assert attention_mask is None
         batch_size, seq_length = hidden_states.shape[:2]
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
-        seq_length_with_past = seq_length + past_length
-        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+        causal_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
         return super().forward(
-            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
+            hidden_states, *args, attention_mask=causal_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
 
     def _prepare_attn_mask(
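
With this hunk the wrapper no longer fabricates its own all-ones mask; whatever mask reaches forward() is used both for build_alibi_tensor and for _prepare_attn_mask, and the prepared result is kept in a separate causal_mask variable instead of overwriting attention_mask. For intuition only, here is a small sketch (not part of the commit) of the kind of tensor causal_mask ends up holding in the BLOOM implementation: a (batch, seq) padding mask combined with a lower-triangular causal mask. The helper name and shapes below are assumptions made for illustration.

    import torch

    def make_causal_mask(attention_mask: torch.Tensor, past_length: int = 0) -> torch.Tensor:
        # Illustrative stand-in for what BloomBlock._prepare_attn_mask produces.
        # attention_mask: (batch, past_length + seq_length); 1 = real token, 0 = padding.
        batch_size, total_length = attention_mask.shape
        seq_length = total_length - past_length
        # True means "may NOT be attended to"
        causal = torch.triu(
            torch.ones(seq_length, total_length, dtype=torch.bool), diagonal=1 + past_length
        )
        padding = ~attention_mask.bool()[:, None, :]        # (batch, 1, total_length)
        return causal[None, :, :] | padding                 # (batch, seq_length, total_length)

    print(make_causal_mask(torch.ones(1, 4)))
    # tensor([[[False,  True,  True,  True],
    #          [False, False,  True,  True],
    #          [False, False, False,  True],
    #          [False, False, False, False]]])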

+ 7 - 3
src/petals/client/inference_session.py

@@ -78,6 +78,7 @@ class _ServerInferenceSession:
     def step(
         self,
         new_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -104,7 +105,7 @@ class _ServerInferenceSession:
             assert hypo_ids.dtype == torch.int64
 
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states, prompts, hypo_ids)
+        inputs = (new_hidden_states, attention_mask, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
@@ -212,7 +213,9 @@ class InferenceSession:
         assert not self._closed and not self._chosen_spans
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -226,6 +229,7 @@ class InferenceSession:
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
+        attention_mask = attention_mask.cpu()
         prompts = prompts.cpu()
 
         n_input_tokens = inputs.shape[1]
@@ -294,7 +298,7 @@ class InferenceSession:
                     else:
                         inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
-                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
+                    outputs = session.step(inputs, attention_mask, prompts[span.start : span.end], **kwargs)
                     assert (
                         inputs.shape == outputs.shape
                     ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
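
Callers of the client session now pass the mask explicitly as the second positional argument of step(), and it is serialized next to the hidden states, prompts and hypo_ids. A minimal usage sketch; how `session` is obtained, the hidden size and the shapes are illustrative assumptions, not something this commit shows:

    import torch

    # `session` is assumed to be an already-open petals InferenceSession for a BLOOM-like model
    hidden_states = torch.randn(1, 5, 14336)   # (batch, new tokens, hidden size) -- illustrative sizes
    attention_mask = torch.ones(1, 5)          # 1 = attend; here it simply covers the five prompt tokens
    outputs = session.step(hidden_states, attention_mask)   # attention_mask is now a required argument
    assert outputs.shape == hidden_states.shape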

+ 4 - 1
src/petals/client/remote_generation.py

@@ -179,7 +179,10 @@ class RemoteGenerationMixin:
                     hidden_state = torch.cat([prompts, hidden_state], dim=1)
                 hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
 
-                hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                attention_mask = torch.ones((batch_size, seq_idx), device=hidden_state.device)
+                hidden_state = session.step(
+                    hidden_state, attention_mask, prompts=intermediate_prompts, hypo_ids=hypo_ids
+                )[:, -1]
 
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
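
On the generation side the mask is rebuilt on every decoding step as an all-ones tensor of shape (batch_size, seq_idx), so it grows by one column per generated token even though only the newest hidden states are sent. A toy sketch of how those shapes evolve (the prefix length is an illustrative assumption):

    import torch

    batch_size, prefix_length = 1, 4
    for seq_idx in range(prefix_length, prefix_length + 3):
        attention_mask = torch.ones((batch_size, seq_idx))
        print(seq_idx, tuple(attention_mask.shape))
    # 4 (1, 4)
    # 5 (1, 5)
    # 6 (1, 6)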

+ 5 - 4
src/petals/client/remote_model.py

@@ -167,8 +167,6 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
-
         for k, v in kwargs.items():
             if not (v is None or v is False):
                 logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
@@ -191,13 +189,16 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, input_shape[-1]), device=hidden_states.device)
+
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
 
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
+            hidden_states = self.h(hidden_states, attention_mask, prompts=intermediate_prompts)
         else:
-            hidden_states = self.h(hidden_states)
+            hidden_states = self.h(hidden_states, attention_mask)
 
         # Remove prefix
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
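
With the hard assert removed, DistributedBloomModel.forward accepts a user-supplied attention_mask (falling back to all-ones when none is given) and passes it into self.h alongside the hidden states. A hedged usage sketch; the import path, checkpoint names and padding setup are illustrative assumptions rather than something this commit guarantees:

    import torch
    from transformers import BloomTokenizerFast
    from petals import DistributedBloomModel   # assumed public import, for illustration only

    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
    model = DistributedBloomModel.from_pretrained("bigscience/bloom-petals")

    batch = tokenizer(["hello world", "hi"], return_tensors="pt", padding=True)
    # Before this commit the call below hit the
    # "DistributedBloomModel does not support attention masks right now" assert;
    # now the padding mask is forwarded to the remote blocks.
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])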

+ 3 - 0
src/petals/server/server.py

@@ -420,6 +420,9 @@ class ModuleContainer(threading.Thread):
                         BatchTensorDescriptor(
                             1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
                         ),
+                        BatchTensorDescriptor(
+                            1, 2048, dtype=backend_dtype, compression=compression
+                        ),
                     ),
                     kwargs_schema={},
                     outputs_schema=(