Artem Chumachenko, 2 years ago
parent
commit
9de60e7895
3 changed files with 9 additions and 3 deletions
  1. src/petals/bloom/block.py (+2 -2)
  2. src/petals/client/remote_generation.py (+4 -1)
  3. src/petals/client/remote_model.py (+3 -0)

src/petals/bloom/block.py (+2 -2)

@@ -28,9 +28,9 @@ class WrappedBloomBlock(BloomBlock):
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+        causal_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
         return super().forward(
-            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
+            hidden_states, *args, attention_mask=causal_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
 
     def _prepare_attn_mask(
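
For context, a minimal sketch of what a _prepare_attn_mask-style helper does with the renamed variable: it combines the caller-supplied padding mask with a lower-triangular causal mask, and that combined tensor (here causal_mask) is what gets passed on to the parent block's forward. The helper below is illustrative only; the real transformers implementation may use the opposite boolean convention (True = masked) and different internal names.

    import torch

    def prepare_attn_mask_sketch(attention_mask, input_shape, past_length):
        # attention_mask: (batch_size, src_length), 1 for real tokens, 0 for padding
        batch_size, seq_length = input_shape
        src_length = past_length + seq_length

        # Causal part: query position i may attend to key positions <= i + past_length
        causal = torch.tril(
            torch.ones((seq_length, src_length), dtype=torch.bool, device=attention_mask.device),
            diagonal=past_length,
        )

        # Padding part: broadcast the per-token mask over the query dimension
        padding = attention_mask.bool()[:, None, :]           # (batch, 1, src_length)

        # True = "may attend"; downstream code turns this into an additive attention bias
        return causal[None, :, :] & padding                   # (batch, seq_length, src_length)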

src/petals/client/remote_generation.py (+4 -1)

@@ -179,7 +179,10 @@ class RemoteGenerationMixin:
                     hidden_state = torch.cat([prompts, hidden_state], dim=1)
                 hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
 
-                hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                attention_mask = torch.ones((batch_size, seq_idx), device=hidden_state.device)
+                hidden_state = session.step(
+                    hidden_state, attention_mask, prompts=intermediate_prompts, hypo_ids=hypo_ids
+                )[:, -1]
 
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
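
The generation loop now sends an explicit all-ones attention mask along with the hidden states, sized to cover every token processed so far (seq_idx), so the serving block can rebuild the ALiBi bias and causal mask for the full sequence rather than only the newest token. A toy, shape-only sketch of how such a mask grows across decoding steps, under the assumption that seq_idx counts the cached prefix plus the current token (no real session is involved):

    import torch

    batch_size, prompt_length = 2, 5
    device = torch.device("cpu")

    for step in range(3):
        seq_idx = prompt_length + step + 1   # tokens seen so far, including the new one (assumed)
        attention_mask = torch.ones((batch_size, seq_idx), device=device)
        # session.step(hidden_state, attention_mask, ...) would receive a mask
        # covering the cached prefix plus the freshly generated token.
        print(step, attention_mask.shape)    # torch.Size([2, 6]), then [2, 7], [2, 8]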

src/petals/client/remote_model.py (+3 -0)

@@ -189,6 +189,9 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, input_shape[-1]), device=hidden_states.device)
+
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
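
When the caller does not pass an attention mask, the model now falls back to an all-ones mask over the (possibly prompt-extended) input length, matching the usual HuggingFace convention that a missing mask means no padding. A minimal sketch of that default, written against inputs_embeds rather than the exact local variables of this method:

    import torch

    def default_attention_mask(inputs_embeds):
        # All positions are treated as real tokens (1 = attend), which is what
        # callers implicitly get when they omit the mask.
        batch_size, seq_length, _ = inputs_embeds.shape
        return torch.ones((batch_size, seq_length), device=inputs_embeds.device)

    # Usage:
    # if attention_mask is None:
    #     attention_mask = default_attention_mask(inputs_embeds)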