|
@@ -179,7 +179,10 @@ class RemoteGenerationMixin:
|
|
hidden_state = torch.cat([prompts, hidden_state], dim=1)
|
|
hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
|
|
|
|
|
|
- hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
|
|
|
|
|
|
+ attention_mask = torch.ones((batch_size, seq_idx), device=hidden_state.device)
|
|
|
|
+ hidden_state = session.step(
|
|
|
|
+ hidden_state, attention_mask, prompts=intermediate_prompts, hypo_ids=hypo_ids
|
|
|
|
+ )[:, -1]
|
|
|
|
|
|
hidden_state = self.transformer.ln_f(hidden_state)
|
|
lm_logits = self.lm_head(hidden_state)
|