瀏覽代碼

fix generate

artek0chumak 3 年之前
父節點
當前提交
2e0abe48ae
共有 2 個文件被更改,包括 32 次插入2 次删除
  1. 1 1
      src/client/remote_generation.py
  2. 31 1
      src/client/remote_model.py

+ 1 - 1
src/client/remote_generation.py

@@ -9,7 +9,7 @@ from transformers.modeling_utils import PreTrainedModel
 
 
 class RemoteGenerationMixin(PreTrainedModel):
-    def generation(
+    def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
         do_sample: Optional[bool] = None,

+ 31 - 1
src/client/remote_model.py

@@ -1,6 +1,7 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Tuple
+import torch
+from typing import List, Optional, Union, Tuple
 
 import hivemind
 import torch
@@ -17,6 +18,8 @@ from src.bloom.model import (
 )
 from src.client.remote_sequential import RemoteSequential
 from src.client.remote_generation import RemoteGenerationMixin
+from src.utils.generation_algorithms import DecodingAlgorithm
+from src.utils.generation_constraints import ABConstraint
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -171,6 +174,33 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
             self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
             self.lm_head.bias[...] = new_lm_head.bias
 
+    def generate(  # same parameter list as the call below forwards, in the same order
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        do_sample: Optional[bool] = None,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        eos_token_id: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        decoding_algorithm: Optional[DecodingAlgorithm] = None,
+        provided_constraints: Optional[List[ABConstraint]] = None,  # None means "no constraints"
+        **model_kwargs,
+    ) -> torch.Tensor:
+        return RemoteGenerationMixin.generate(  # call the mixin's generate explicitly on this instance
+            self,
+            inputs,
+            do_sample,
+            temperature,
+            top_k,
+            top_p,
+            eos_token_id,
+            max_new_tokens,
+            decoding_algorithm,
+            [] if provided_constraints is None else provided_constraints,  # fresh list per call, never a shared default
+            **model_kwargs,
+        )
+
 
 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig