
Add pre_seq_len to prefix length

Artem Chumachenko · 3 years ago · parent commit 3deb385865
1 changed file with 1 addition and 0 deletions

src/client/remote_generation.py (+1 −0)

@@ -63,6 +63,7 @@ class RemoteGenerationMixin:
         if inputs is not None:
             assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
         prefix_length = 0 if inputs is None else inputs.size(1)
+        prefix_length += self.config.pre_seq_len
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
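For context, the added line counts the model's trainable soft-prompt tokens (whose length the config exposes as pre_seq_len) toward the prefix length, since those tokens occupy sequence positions even though they never appear in inputs. A minimal, self-contained sketch of that accounting, using a hypothetical DummyConfig and example sizes rather than the project's real classes:

```python
import torch

class DummyConfig:
    # Hypothetical stand-in for the model config; pre_seq_len is the number
    # of trainable prompt-tuning tokens prepended to every sequence.
    pre_seq_len = 16

config = DummyConfig()
inputs = torch.randint(0, 100, (1, 8))  # [batch, length] token ids

# Same accounting as the patched code: positions used by real input tokens...
prefix_length = 0 if inputs is None else inputs.size(1)
# ...plus the soft-prompt tokens, which also consume sequence positions.
prefix_length += config.pre_seq_len

print(prefix_length)  # 8 input tokens + 16 prompt-tuning tokens = 24
```

Presumably, without the added line, any length budget derived from prefix_length would fail to reserve room for the prompt-tuning tokens.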