|
@@ -179,9 +179,8 @@ class RemoteGenerationMixin:
|
|
|
hidden_state = torch.cat([prompts, hidden_state], dim=1)
|
|
|
hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
|
|
|
|
|
|
- attention_mask = torch.ones((batch_size, seq_idx), device=hidden_state.device)
|
|
|
hidden_state = session.step(
|
|
|
- hidden_state, attention_mask, prompts=intermediate_prompts, hypo_ids=hypo_ids
|
|
|
+ hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids
|
|
|
)[:, -1]
|
|
|
|
|
|
hidden_state = self.transformer.ln_f(hidden_state)
|