소스 검색

Add prefix-tuned inference

Artem Chumachenko 3 년 전
부모
커밋
a440fb5fba
1개의 변경된 파일3개의 추가작업 그리고 0개의 파일을 삭제
  1. 3 0
      src/client/remote_generation.py

+ 3 - 0
src/client/remote_generation.py

@@ -104,6 +104,9 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                if self.config.pre_seq_len > 0 and len(outputs) == 1:
+                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
                 hidden_state = sess.step(embs)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)