Browse Source

Add prefix-tuned inference

Artem Chumachenko 3 years ago
parent
commit
e0b9239cff
1 changed files with 3 additions and 0 deletions
  1. 3 0
      src/client/remote_generation.py

+ 3 - 0
src/client/remote_generation.py

@@ -98,6 +98,9 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                if self.config.pre_seq_len > 0 and len(outputs) == 1:
+                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
                 hidden_state = sess.step(embs)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)