|
@@ -17,8 +17,6 @@ from src.bloom.model import (
|
|
|
)
|
|
|
from src.client.remote_generation import RemoteGenerationMixin
|
|
|
from src.client.remote_sequential import RemoteSequential
|
|
|
-from src.utils.generation_algorithms import DecodingAlgorithm
|
|
|
-from src.utils.generation_constraints import ABCBloomConstraint
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
@@ -156,7 +154,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
|
|
|
return transformer_outputs
|
|
|
|
|
|
|
|
|
-class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
|
|
|
+class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
|
|
|
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
|
|
|
|
|
config_class = DistributedBloomConfig
|
|
@@ -190,33 +188,6 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
|
|
|
self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
|
|
|
self.lm_head.bias[...] = new_lm_head.bias
|
|
|
|
|
|
- def generate(
|
|
|
- self,
|
|
|
- inputs: Optional[torch.Tensor] = None,
|
|
|
- do_sample: Optional[bool] = None,
|
|
|
- temperature: float = 1.0,
|
|
|
- top_k: Optional[int] = None,
|
|
|
- top_p: Optional[float] = None,
|
|
|
- eos_token_id: Optional[int] = None,
|
|
|
- max_new_tokens: Optional[int] = None,
|
|
|
- decoding_algorithm: Optional[DecodingAlgorithm] = None,
|
|
|
- provided_constraints: List[ABCBloomConstraint] = [],
|
|
|
- **model_kwargs,
|
|
|
- ) -> torch.Tensor:
|
|
|
- return RemoteGenerationMixin.generate(
|
|
|
- self,
|
|
|
- inputs=inputs,
|
|
|
- do_sample=do_sample,
|
|
|
- temperature=temperature,
|
|
|
- top_k=top_k,
|
|
|
- top_p=top_p,
|
|
|
- eos_token_id=eos_token_id,
|
|
|
- max_new_tokens=max_new_tokens,
|
|
|
- decoding_algorithm=decoding_algorithm,
|
|
|
- provided_constraints=provided_constraints,
|
|
|
- **model_kwargs,
|
|
|
- )
|
|
|
-
|
|
|
|
|
|
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
|
|
|
config_class = DistributedBloomConfig
|