|
@@ -61,7 +61,7 @@ class RemoteGenerationMixin:
|
|
|
model_kwargs.get("stopping_criteria", None) is None
|
|
|
), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
|
|
|
if inputs is not None:
|
|
|
- assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, "inputs must be a 3d tensor [batch, len, hid]"
|
|
|
+ assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
|
|
|
prefix_length = 0 if inputs is None else inputs.size(1)
|
|
|
|
|
|
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|