@@ -179,9 +179,7 @@ class RemoteGenerationMixin:
 hidden_state = torch.cat([prompts, hidden_state], dim=1)
 hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)

-hidden_state = session.step(
-    hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids
-)[:, -1]
+hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]

 hidden_state = self.transformer.ln_f(hidden_state)
 lm_logits = self.lm_head(hidden_state)