@@ -40,7 +40,7 @@ class RemoteGenerationMixin:
         return self.transformer.h.inference_session(**kwargs)
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -171,13 +171,15 @@ class RemoteGenerationMixin:
         seq_idx = outputs[0].size(1)
         hypo_ids = torch.arange(outputs[0].size(0))
         while True:
-            embs = self.transformer.word_embeddings(outputs[-1])
+            hidden_state = self.transformer.word_embeddings(outputs[-1])
             intermediate_prompts = None
             if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
-                embs = torch.cat([prompts, embs], dim=1)
-            embs = self.transformer.word_embeddings_layernorm(embs)
-            hidden_state = session.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
+                hidden_state = torch.cat([prompts, hidden_state], dim=1)
+            hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
+
+            hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+
             hidden_state = self.transformer.ln_f(hidden_state)
             lm_logits = self.lm_head(hidden_state)
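For context (an illustrative sketch, not code from this PR): `torch.inference_mode()` behaves like `torch.no_grad()` in that neither records gradients, but inference mode also skips autograd's view and version-counter bookkeeping, so it is slightly cheaper for a pure generation loop; the trade-off is that tensors created inside it cannot later take part in autograd. The toy model below is hypothetical and only demonstrates the two decorators.

```python
import torch

# Sketch only (not part of this PR): compare the two decorators on a dummy model.
@torch.no_grad()
def step_no_grad(model, x):
    # Gradients are not recorded, but the output is still a regular tensor.
    return model(x)

@torch.inference_mode()
def step_inference(model, x):
    # Output is an "inference tensor": cheaper to produce, but it cannot be
    # used in autograd-recording operations afterwards.
    return model(x)

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

print(step_no_grad(model, x).requires_grad)     # False
print(step_inference(model, x).is_inference())  # True
```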