浏览代码

Raise error for unexpected .generate() kwargs (#315)

Now, if a user passes unexpected kwargs to `.generate()`, they are __ignored__ and the code continues working as if the argument was correctly supported. For example, people often tried passing `repetition_penalty` and didn't notice that it does not have any effect. This PR fixes this problem.
Alexander Borzunov 2 年之前
父节点
当前提交
6eb306a605
共有 1 个文件被更改,包括 1 次插入和 22 次删除
  1. 1 22
      src/petals/client/remote_generation.py

+ 1 - 22
src/petals/client/remote_generation.py

@@ -44,6 +44,7 @@ class RemoteGenerationMixin:
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
+        *,
         do_sample: Optional[bool] = None,
         temperature: float = 1.0,
         top_k: Optional[int] = None,
@@ -57,9 +58,7 @@ class RemoteGenerationMixin:
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
         num_return_sequences: Optional[int] = None,
-        *,
         session: Optional[InferenceSession] = None,
-        **model_kwargs,
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head.
@@ -77,19 +76,9 @@ class RemoteGenerationMixin:
         :param max_new_tokens: The maximum number of tokens to generate.
         :param decoding_algorithm: The decoding algorithm to use.
         :param provided_constraints: A list of constraints to use.
-        :param model_kwargs: Additional arguments to pass to the model.
         :param num_return_sequences: How many hypothesis from the beam will be in output.
         """
 
-        assert (
-            model_kwargs.get("logits_processor", None) is None
-        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
-        assert (
-            model_kwargs.get("logits_wrapper", None) is None
-        ), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
-        assert (
-            model_kwargs.get("stopping_criteria", None) is None
-        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
         prefix_length = 0 if inputs is None else inputs.size(1)
         prefix_length += self.config.pre_seq_len
 
@@ -226,7 +215,6 @@ class RemoteGenerationMixin:
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head. Uses greedy search.
@@ -244,7 +232,6 @@ class RemoteGenerationMixin:
             eos_token_id=eos_token_id,
             decoding_algorithm=GreedyAlgorithm(),
             provided_constraints=provided_constraints,
-            **model_kwargs,
         )
 
     def sample(
@@ -257,7 +244,6 @@ class RemoteGenerationMixin:
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
@@ -271,7 +257,6 @@ class RemoteGenerationMixin:
         :param: pad_token_id: The id of the padding token.
         :param: eos_token_id: The id of the end of sentence token.
         :param: provided_constraints: A list of constraints to use.
-        :param: model_kwargs: Additional kwargs to pass to the model.
         """
 
         return self.generate(
@@ -281,7 +266,6 @@ class RemoteGenerationMixin:
             eos_token_id=eos_token_id,
             decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
             provided_constraints=provided_constraints,
-            **model_kwargs,
         )
 
     def beam_search(
@@ -292,7 +276,6 @@ class RemoteGenerationMixin:
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head. Uses beam search.
@@ -303,7 +286,6 @@ class RemoteGenerationMixin:
         :param pad_token_id: The id of the padding token.
         :param eos_token_id: The id of the end of sentence token.
         :param provided_constraints: A list of constraints to use.
-        :param: model_kwargs: Additional kwargs to pass to the model.
         """
         decoding_algorithm = BeamSearchAlgorithm(
             num_beams=num_beams,
@@ -317,7 +299,6 @@ class RemoteGenerationMixin:
             eos_token_id=eos_token_id,
             decoding_algorithm=decoding_algorithm,
             provided_constraints=provided_constraints,
-            **model_kwargs,
         )
 
     def beam_sample(
@@ -327,7 +308,6 @@ class RemoteGenerationMixin:
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
     ) -> torch.LongTensor:
         raise NotImplementedError
 
@@ -338,7 +318,6 @@ class RemoteGenerationMixin:
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
     ) -> torch.LongTensor:
         raise NotImplementedError