5
0
Эх сурвалжийг харах

Add preseq_length in prefix size

Artem Chumachenko 3 жил өмнө
parent
commit
3deb385865

+ 1 - 0
src/client/remote_generation.py

@@ -63,6 +63,7 @@ class RemoteGenerationMixin:
         if inputs is not None:
             assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
         prefix_length = 0 if inputs is None else inputs.size(1)
+        prefix_length += self.config.pre_seq_len
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id