@@ -2,10 +2,13 @@
 import os
 from typing import Optional, Tuple
 
+import torch
+import torch.nn as nn
+
 import hivemind
 from hivemind import get_logger, use_hivemind_log_handler
 
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead, BloomForSequenceClassification
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
@@ -22,7 +25,8 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
+    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for an LM head for efficient half-precision on CPU
+    num_prefix_tokens: int = 0  # the number of prefix tokens for prompt tuning
 
 
 class DistributedBloomModel(BloomModel):
@@ -55,6 +59,69 @@ class DistributedBloomModel(BloomModel):
             p.requires_grad = value
 
 
+class DistributedBloomPrefix(DistributedBloomModel):
+    """DistributedBloomModel with prefix tokens for prompt tuning"""
+
+    def __init__(self, config):
+        super().__init__(config)
+        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
+        self.prefix_length = config.num_prefix_tokens
+
+        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
+        self.prefix_tokens = torch.arange(self.prefix_length).long()
+
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+        return prompts
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor],
+        inputs_embeds: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values=None,
+        position_ids=None,
+        head_mask=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None
+    ):
+        assert input_ids is None or inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time"
+        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        batch_size = inputs_embeds.shape[0]
+
+        if attention_mask is not None:
+            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+        prompts = self.get_prompt(batch_size)
+        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        transformer_outputs = super().forward(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        # Remove prefix
+        last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
+        transformer_outputs['last_hidden_state'] = last_hidden_state
+        return transformer_outputs
+
+
 class DistributedBloomForCausalLM(BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
@@ -62,8 +129,12 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
 
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        self.transformer = DistributedBloomModel(config)
+        if config.num_prefix_tokens > 0:
+            self.transformer = DistributedBloomPrefix(config)
+        else:
+            self.transformer = DistributedBloomModel(config)
         self.lm_head = LMHead(config, self.transformer.word_embeddings)
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -72,3 +143,17 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.word_embeddings.weight = new_embeddings.weight
+
+
+class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        super().__init__(config)
+        if config.num_prefix_tokens > 0:
+            self.transformer = DistributedBloomPrefix(config)
+        else:
+            self.transformer = DistributedBloomModel(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
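
For orientation, a minimal standalone sketch of the prefix-prompt mechanics that DistributedBloomPrefix.forward applies. The sizes are arbitrary, nothing here is part of the patch, and no hivemind swarm is required:

    # Illustration of the prefix-tuning data flow with made-up sizes.
    import torch
    import torch.nn as nn

    batch_size, seq_len, hidden_size, prefix_length = 2, 7, 16, 4

    prompt_embeddings = nn.Embedding(prefix_length, hidden_size)   # trainable prompts
    prefix_tokens = torch.arange(prefix_length).long()

    inputs_embeds = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for word embeddings
    attention_mask = torch.ones(batch_size, seq_len)

    # get_prompt(): one copy of the prompt embeddings per example in the batch
    prompts = prompt_embeddings(prefix_tokens.unsqueeze(0).expand(batch_size, -1))
    assert prompts.shape == (batch_size, prefix_length, hidden_size)

    # forward(): prepend the prompts and extend the attention mask to match
    inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)  # [2, 11, 16]
    attention_mask = torch.cat([torch.ones(batch_size, prefix_length), attention_mask], dim=1)  # [2, 11]

    # After the backbone runs, the prefix positions are sliced off again, so
    # downstream heads only see hidden states for the original sequence.
    last_hidden_state = torch.randn(batch_size, prefix_length + seq_len, hidden_size)
    assert last_hidden_state[:, prefix_length:].shape == (batch_size, seq_len, hidden_size)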
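A hedged end-to-end usage sketch follows. Both heads pick DistributedBloomPrefix as the backbone whenever config.num_prefix_tokens > 0, so prompt tuning is enabled purely through the config. The checkpoint name, peer multiaddress, and token ids below are placeholders, the module path src.client.remote_model is assumed from the imports in this file, and the call requires a reachable swarm that actually serves this model's blocks:

    # Usage sketch; placeholder checkpoint/peer values, requires a running swarm.
    import torch
    from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM

    config = DistributedBloomConfig.from_pretrained("bigscience/test-bloomd")  # placeholder name
    config.initial_peers = ["/ip4/127.0.0.1/tcp/31337/p2p/PEER_ID"]            # placeholder peer
    config.num_prefix_tokens = 16                                              # > 0 selects DistributedBloomPrefix

    model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd", config=config)

    # model.transformer is now a DistributedBloomPrefix: 16 prompt embeddings are
    # prepended inside forward() and stripped again before the LM head is applied,
    # so the logits cover only the original input positions.
    input_ids = torch.tensor([[2, 3, 4]])       # placeholder token ids
    logits = model(input_ids=input_ids).logits  # expected shape [1, 3, vocab_size]

    # For prompt tuning one would typically optimize only the prompt parameters:
    optimizer = torch.optim.Adam(model.transformer.prompt_embeddings.parameters(), lr=1e-3)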