|
@@ -68,12 +68,12 @@ class RemoteGenerationMixin:
|
|
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
|
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
|
|
|
|
|
+ assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
|
|
|
if max_length is not None and max_new_tokens is None:
|
|
|
max_new_tokens = max_length - prefix_length
|
|
|
assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
|
|
|
elif max_length is None and max_new_tokens is not None:
|
|
|
max_length = prefix_length + max_new_tokens
|
|
|
- assert max_length is not None and max_new_tokens is not None
|
|
|
|
|
|
if inputs is None:
|
|
|
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
|