@@ -1,11 +1,11 @@
 # this code is in active development, interfaces may change

-import os
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 import hivemind
 import torch
 import torch.nn as nn
 from hivemind import get_logger, use_hivemind_log_handler
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

 from src.bloom.model import (
     BloomConfig,
@@ -17,8 +17,6 @@ from src.bloom.model import (
 )
 from src.client.remote_generation import RemoteGenerationMixin
 from src.client.remote_sequential import RemoteSequential
-from src.utils.generation_algorithms import DecodingAlgorithm
-from src.utils.generation_constraints import ABCBloomConstraint

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -34,7 +32,7 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
-    num_prefix_tokens: int = 0  # a number of tokens for prompt tuning.
+    pre_seq_len: int = 0  # a number of tokens for prompt tuning.


 class DistributedBloomModel(BloomModel):
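# Illustrative sketch, not part of the patch above: after the rename, prompt tuning is switched on
# by setting `pre_seq_len` (the number of trainable prefix tokens) instead of `num_prefix_tokens`.
# The DHT-related fields are omitted here, so this config is only good for inspecting the attribute.
config = DistributedBloomConfig(pre_seq_len=16)
assert config.pre_seq_len == 16  # > 0 makes the model classes below pick DistributedBloomPrefix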
@@ -66,13 +64,46 @@ class DistributedBloomModel(BloomModel):
         for p in self.parameters():
             p.requires_grad = value

-    def forward(self, *args, use_cache=None, **kwargs):
-        if use_cache:
-            raise ValueError(
-                "Distributed forward does not support use_cache; for efficient cache-aware generation, "
-                "please use model.transformer.inference_session() or model.generate(...)"
-            )
-        return super().forward(*args, use_cache=False, **kwargs)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        # Note: it supports only float32 or bfloat16 inputs
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        output_shape = input_shape + (hidden_states.size(-1),)
+        hidden_states = self.h(hidden_states)
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )


 class DistributedBloomPrefix(DistributedBloomModel):
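# Illustrative usage sketch, not part of the patch above. It assumes `model` is an
# already-initialized DistributedBloomModel connected to a swarm. The rewritten forward embeds
# tokens locally, pushes hidden states through the remote layers in `self.h`, applies the final
# layer norm, and returns a BaseModelOutputWithPastAndCrossAttentions with only
# last_hidden_state populated.
import torch

input_ids = torch.tensor([[2, 3, 4]])  # (batch, seq_len); attention_mask must stay None for now
outputs = model(input_ids=input_ids)
assert outputs.last_hidden_state.shape == (1, 3, model.config.hidden_size)
assert outputs.past_key_values is None and outputs.attentions is None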
@@ -80,11 +111,11 @@ class DistributedBloomPrefix(DistributedBloomModel):

     def __init__(self, config):
         super().__init__(config)
-        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
-        self.prefix_length = config.num_prefix_tokens
+        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+        self.pre_seq_len = config.pre_seq_len

-        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.prefix_length).long()
+        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+        self.prefix_tokens = torch.arange(self.pre_seq_len).long()

     def get_prompt(self, batch_size):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
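# Illustrative sketch, not part of the patch above: the shapes produced by the prompt tensors
# introduced here, using stand-in sizes. Every batch element receives the same pre_seq_len
# learned prompt vectors.
import torch
import torch.nn as nn

pre_seq_len, hidden_size, batch_size = 16, 64, 2  # stand-in values
prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size)
prefix_tokens = torch.arange(pre_seq_len).long()

prompts = prompt_embeddings(prefix_tokens.unsqueeze(0).expand(batch_size, -1))
assert prompts.shape == (batch_size, pre_seq_len, hidden_size)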
@@ -94,16 +125,10 @@ class DistributedBloomPrefix(DistributedBloomModel):

     def forward(
         self,
-        input_ids: Optional[torch.LongTensor],
-        inputs_embeds: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_values=None,
-        position_ids=None,
-        head_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         assert (
             input_ids is None or inputs_embeds is None
@@ -122,17 +147,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
         prompts = self.get_prompt(batch_size)
         inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

-        transformer_outputs = super().forward(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
+        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)

         # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
+        last_hidden_state = transformer_outputs[0][:, self.pre_seq_len :]
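# Illustrative sketch, not part of the patch above: the simplified forward prepends pre_seq_len
# prompt embeddings before calling the parent forward and strips them from the returned hidden
# states, so callers see tensors aligned with their original inputs. Stand-in shapes below.
import torch

pre_seq_len, batch_size, seq_len, hidden_size = 16, 2, 5, 64
prompts = torch.zeros(batch_size, pre_seq_len, hidden_size)
inputs_embeds = torch.zeros(batch_size, seq_len, hidden_size)

extended = torch.cat([prompts, inputs_embeds], dim=1)  # what goes into super().forward(...)
assert extended.shape == (batch_size, pre_seq_len + seq_len, hidden_size)

last_hidden_state = extended[:, pre_seq_len:]  # what the caller gets back
assert last_hidden_state.shape == (batch_size, seq_len, hidden_size)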
@@ -140,14 +155,14 @@ class DistributedBloomPrefix(DistributedBloomModel):
         return transformer_outputs


-class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
+class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""

     config_class = DistributedBloomConfig

     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)
@@ -174,40 +189,13 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
         self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
         self.lm_head.bias[...] = new_lm_head.bias

-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        do_sample: Optional[bool] = None,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        eos_token_id: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        decoding_algorithm: Optional[DecodingAlgorithm] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
-    ) -> torch.Tensor:
-        return RemoteGenerationMixin.generate(
-            self,
-            inputs=inputs,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            eos_token_id=eos_token_id,
-            max_new_tokens=max_new_tokens,
-            decoding_algorithm=decoding_algorithm,
-            provided_constraints=provided_constraints,
-            **model_kwargs,
-        )
-

 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig

     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.num_prefix_tokens > 0:
+        if config.pre_seq_len > 0:
             self.transformer = DistributedBloomPrefix(config)
         else:
             self.transformer = DistributedBloomModel(config)
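# Illustrative sketch, not part of the patch above: why the explicit `generate` override could be
# dropped once the base classes were reordered. With RemoteGenerationMixin listed first, Python's
# MRO resolves `generate` to the mixin before BloomForCausalLM and transformers' GenerationMixin.
# The class names below are stand-ins, not the real implementations.
class GenerationMixin:  # plays the role of transformers' GenerationMixin
    def generate(self):
        return "local HF generation"

class BloomForCausalLM(GenerationMixin):  # plays the role of the HF model class
    pass

class RemoteGenerationMixin:  # plays the role of src.client.remote_generation
    def generate(self):
        return "remote swarm generation"

class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
    pass

assert DistributedBloomForCausalLM().generate() == "remote swarm generation"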