|
@@ -2,15 +2,20 @@
|
|
|
import os
|
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
+import hivemind
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
-
|
|
|
-import hivemind
|
|
|
from hivemind import get_logger, use_hivemind_log_handler
|
|
|
|
|
|
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead, BloomForSequenceClassification
|
|
|
+from src.bloom.model import (
|
|
|
+ BloomConfig,
|
|
|
+ BloomForCausalLM,
|
|
|
+ BloomForSequenceClassification,
|
|
|
+ BloomModel,
|
|
|
+ BloomPreTrainedModel,
|
|
|
+ LMHead,
|
|
|
+)
|
|
|
from src.client.remote_sequential import RemoteSequential
|
|
|
-from src.data_structures import UID_DELIMITER
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
@@ -25,12 +30,13 @@ class DistributedBloomConfig(BloomConfig):
|
|
|
initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT
|
|
|
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
|
|
|
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
|
|
|
- chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
|
|
|
- num_prefix_tokens: int = 0 # a number of tokens for prompt tuning.
|
|
|
+ chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
|
|
|
+ num_prefix_tokens: int = 0 # a number of tokens for prompt tuning.
|
|
|
|
|
|
|
|
|
class DistributedBloomModel(BloomModel):
|
|
|
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
|
|
+
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
@@ -49,7 +55,7 @@ class DistributedBloomModel(BloomModel):
|
|
|
)
|
|
|
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
|
|
|
self.h = RemoteSequential(config, dht, config.dht_prefix)
|
|
|
-
|
|
|
+
|
|
|
# Forbid accumulate grads for embeddings and layernorm
|
|
|
self.set_requires_grad(False)
|
|
|
|
|
@@ -57,6 +63,14 @@ class DistributedBloomModel(BloomModel):
|
|
|
for p in self.parameters():
|
|
|
p.requires_grad = value
|
|
|
|
|
|
+ def forward(self, *args, use_cache=None, **kwargs):
|
|
|
+ if use_cache:
|
|
|
+ raise ValueError(
|
|
|
+ "Distributed forward does not support use_cache; for efficient cache-aware generation, "
|
|
|
+ "please use model.transformer.inference_session() or model.generate(...)"
|
|
|
+ )
|
|
|
+ return super().forward(*args, use_cache=False, **kwargs)
|
|
|
+
|
|
|
|
|
|
class DistributedBloomPrefix(DistributedBloomModel):
|
|
|
"""DistributedBloomModel with prefix tokens for prompt tuning"""
|
|
@@ -76,7 +90,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
|
|
|
return prompts
|
|
|
|
|
|
def forward(
|
|
|
- self,
|
|
|
+ self,
|
|
|
input_ids: Optional[torch.LongTensor],
|
|
|
inputs_embeds: Optional[torch.Tensor],
|
|
|
attention_mask: Optional[torch.Tensor],
|
|
@@ -86,14 +100,16 @@ class DistributedBloomPrefix(DistributedBloomModel):
|
|
|
use_cache=None,
|
|
|
output_attentions=None,
|
|
|
output_hidden_states=None,
|
|
|
- return_dict=None
|
|
|
+ return_dict=None,
|
|
|
):
|
|
|
- assert input_ids is None or inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time"
|
|
|
+ assert (
|
|
|
+ input_ids is None or inputs_embeds is None
|
|
|
+ ), "You cannot specify both input_ids and inputs_embeds at the same time"
|
|
|
assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
|
|
|
-
|
|
|
+
|
|
|
if inputs_embeds is None:
|
|
|
inputs_embeds = self.word_embeddings(input_ids)
|
|
|
-
|
|
|
+
|
|
|
batch_size = inputs_embeds.shape[0]
|
|
|
|
|
|
if attention_mask is not None:
|
|
@@ -104,25 +120,26 @@ class DistributedBloomPrefix(DistributedBloomModel):
|
|
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
|
|
|
|
|
transformer_outputs = super().forward(
|
|
|
- inputs_embeds=inputs_embeds,
|
|
|
- attention_mask=attention_mask,
|
|
|
+ inputs_embeds=inputs_embeds,
|
|
|
+ attention_mask=attention_mask,
|
|
|
past_key_values=past_key_values,
|
|
|
position_ids=position_ids,
|
|
|
head_mask=head_mask,
|
|
|
use_cache=use_cache,
|
|
|
output_attentions=output_attentions,
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
- return_dict=return_dict
|
|
|
+ return_dict=return_dict,
|
|
|
)
|
|
|
|
|
|
# Remove prefix
|
|
|
- last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
|
|
|
- transformer_outputs['last_hidden_state'] = last_hidden_state
|
|
|
+ last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
|
|
|
+ transformer_outputs["last_hidden_state"] = last_hidden_state
|
|
|
return transformer_outputs
|
|
|
|
|
|
|
|
|
class DistributedBloomForCausalLM(BloomForCausalLM):
|
|
|
- """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
|
|
+ """Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
|
|
+
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
@@ -136,11 +153,23 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
|
|
|
# Initialize weights and apply final processing
|
|
|
self.post_init()
|
|
|
|
|
|
- def get_output_embeddings(self):
|
|
|
- return self.lm_head.word_embeddings
|
|
|
+ def get_input_embeddings(self):
|
|
|
+ return self.transformer.word_embeddings
|
|
|
|
|
|
- def set_output_embeddings(self, new_embeddings):
|
|
|
- self.lm_head.word_embeddings.weight = new_embeddings.weight
|
|
|
+ def get_output_embeddings(self):
|
|
|
+ if self.config.tie_word_embeddings:
|
|
|
+ return None
|
|
|
+ return self.lm_head
|
|
|
+
|
|
|
+ def set_input_embeddings(self, new_embeddings: nn.Embedding):
|
|
|
+ assert isinstance(new_embeddings, nn.Embedding)
|
|
|
+ self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
|
|
|
+ assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
|
|
|
+
|
|
|
+ def set_output_embeddings(self, new_lm_head: nn.Linear):
|
|
|
+ with torch.no_grad():
|
|
|
+ self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
|
|
|
+ self.lm_head.bias[...] = new_lm_head.bias
|
|
|
|
|
|
|
|
|
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
|