@@ -189,11 +189,11 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
         prompts, intermediate_prompts = self.get_prompt(batch_size)
         inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, input_shape[-1]), device=hidden_states.device)
-
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)
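
The hunk moves the default attention-mask construction to after the learned prompts are prepended, so the mask length matches the extended sequence rather than the original input length (and hidden_states is defined before it is referenced). Below is a minimal sketch of the shape reasoning only; the toy dimensions are assumptions, prompts stands in for self.get_prompt(...), and the layernorm is replaced by an identity since it does not change shapes.

# Toy dimensions, chosen only for illustration.
import torch

batch_size, seq_len, pre_seq_len, hidden = 2, 5, 3, 8

inputs_embeds = torch.randn(batch_size, seq_len, hidden)
prompts = torch.randn(batch_size, pre_seq_len, hidden)  # stand-in for self.get_prompt(batch_size)

input_shape = (batch_size, seq_len)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)  # sequence grows to pre_seq_len + seq_len
hidden_states = inputs_embeds  # word_embeddings_layernorm preserves the shape

# Before the fix: a mask sized from input_shape[-1] misses the prompt tokens.
short_mask = torch.ones((batch_size, input_shape[-1]))       # shape (2, 5)
# After the fix: a mask sized from hidden_states covers the full sequence.
full_mask = torch.ones((batch_size, hidden_states.size(1)))  # shape (2, 8)

assert full_mask.size(1) == pre_seq_len + seq_len
assert short_mask.size(1) != hidden_states.size(1)  # the old mask would be too short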