|
@@ -98,6 +98,9 @@ class RemoteGenerationMixin:
|
|
|
hypo_ids = torch.arange(outputs[0].size(0))
|
|
|
while True:
|
|
|
embs = self.transformer.word_embeddings(outputs[-1])
|
|
|
+ if self.config.pre_seq_len > 0 and len(outputs) == 1:
|
|
|
+ prompts, _ = self.transformer.get_prompt(embs.size(0))
|
|
|
+ embs = torch.cat([prompts, embs], dim=1)
|
|
|
embs = self.transformer.word_embeddings_layernorm(embs)
|
|
|
hidden_state = sess.step(embs)[:, -1]
|
|
|
hidden_state = self.transformer.ln_f(hidden_state)
|