ソースを参照

Merge branch 'main' into remove-remote-block

Pavel Samygin 3 年前
コミット
81df31ef27
1 ファイル変更、4 行追加、0 行削除
  1. 4 0
      src/client/remote_generation.py

+ 4 - 0
src/client/remote_generation.py

@@ -63,6 +63,7 @@ class RemoteGenerationMixin:
         if inputs is not None:
             assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
         prefix_length = 0 if inputs is None else inputs.size(1)
+        prefix_length += self.config.pre_seq_len
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
@@ -104,6 +105,9 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                if self.config.pre_seq_len > 0 and len(outputs) == 1:
+                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
                 hidden_state = sess.step(embs)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)