@@ -1,6 +1,7 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Tuple
+import torch
+from typing import List, Optional, Union, Tuple
 
 import hivemind
 import torch
@@ -17,6 +18,8 @@ from src.bloom.model import (
 )
 from src.client.remote_sequential import RemoteSequential
 from src.client.remote_generation import RemoteGenerationMixin
+from src.utils.generation_algorithms import DecodingAlgorithm
+from src.utils.generation_constraints import ABConstraint
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -171,6 +174,33 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
         self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
         self.lm_head.bias[...] = new_lm_head.bias
 
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        do_sample: Optional[bool] = None,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        eos_token_id: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        decoding_algorithm: Optional[DecodingAlgorithm] = None,
+        provided_constraints: List[ABConstraint] = [],
+        **model_kwargs,
+    ) -> torch.Tensor:
+        return RemoteGenerationMixin.generate(
+            self,
+            inputs,
+            do_sample,
+            temperature,
+            top_k,
+            top_p,
+            eos_token_id,
+            max_new_tokens,
+            decoding_algorithm,
+            provided_constraints,
+            **model_kwargs,
+        )
+
 
 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig
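
The added override exists because DistributedBloomForCausalLM inherits a generate() both from transformers' BloomForCausalLM and from RemoteGenerationMixin; since BloomForCausalLM comes first in the class's base list, an explicit wrapper is needed to route calls to the mixin, to which it forwards every argument positionally. What follows is a minimal usage sketch, not part of the diff: the module path of the model class, the checkpoint name, and the from_pretrained/tokenizer setup are not visible above and are assumptions.

from transformers import AutoTokenizer

from src.client.remote_model import DistributedBloomForCausalLM  # assumed module path

MODEL_NAME = "bigscience/bloom-petals"  # hypothetical checkpoint name

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)  # loading assumed to follow the standard transformers pattern

# Tokenize a prompt; generate() takes the input ids as its first argument.
input_ids = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# Every decoding knob passes straight through to RemoteGenerationMixin.generate:
# top_k/top_p select a sampling strategy there, decoding_algorithm can supply a
# custom DecodingAlgorithm, and provided_constraints takes ABConstraint objects.
output_ids = model.generate(
    input_ids,
    do_sample=True,
    temperature=0.9,
    top_k=50,
    max_new_tokens=16,
)
print(tokenizer.decode(output_ids[0]))

Forwarding positionally keeps the wrapper trivial, but it also means the parameter order here must stay in lockstep with RemoteGenerationMixin.generate's signature.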