@@ -1,5 +1,5 @@
 # this code is in active development, interfaces may change
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import hivemind
 import torch
@@ -34,9 +34,7 @@ class DistributedBloomConfig(BloomConfig):
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
-    tuning_mode: Optional[
-        str
-    ] = None  # One of the available options for fine-tuning: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
 
 
 class DistributedBloomModel(BloomModel):
@@ -64,8 +62,7 @@ class DistributedBloomModel(BloomModel):
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
 
-        self.tuning_mode = config.tuning_mode
-        if self.tuning_mode and "ptune" in config.tuning_mode:
+        if config.tuning_mode and "ptune" in config.tuning_mode:
             assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
             self.pre_seq_len = config.pre_seq_len
             self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
@@ -73,11 +70,13 @@ class DistributedBloomModel(BloomModel):
 
             if config.tuning_mode == "deep_ptune":
                 self.intermediate_prompt_embeddings = nn.Embedding(
-                    self.pre_seq_len, (config.num_hidden_layers - 1) * config.hidden_size
+                    self.pre_seq_len,
+                    config.num_hidden_layers * config.hidden_size
+                    # ^-- TODO: should be num_hidden_layers - 1
                 )
                 self.intermediate_prompt_embeddings.weight.data.zero_()
-        elif self.tuning_mode:
-            raise NotImplementedError(f"TODO: {self.tuning_mode}")
+        elif config.tuning_mode:
+            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")
 
     def set_requires_grad(self, value):
         for p in self.parameters():
@@ -88,10 +87,10 @@ class DistributedBloomModel(BloomModel):
         prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
         prompts = self.prompt_embeddings(prefix_tokens)
 
-        if self.tuning_mode == "deep_ptune":
+        if self.config.tuning_mode == "deep_ptune":
             intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
             intermediate_prompts = intermediate_prompts.view(
-                batch_size, self.pre_seq_len, len(self.h) - 1, self.hidden_size
+                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
             )
             intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
         else:
@@ -124,7 +123,7 @@ class DistributedBloomModel(BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.tuning_mode and "ptune" in self.tuning_mode:
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -132,13 +131,13 @@ class DistributedBloomModel(BloomModel):
         hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if "ptune" in self.tuning_mode:
+        if "ptune" in self.config.tuning_mode:
             hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
         else:
             hidden_states = self.h(hidden_states)
 
         # Remove prefix
-        if self.tuning_mode and "ptune" in self.tuning_mode:
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
@@ -159,10 +158,7 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
 
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
         self.lm_head = LMHead(config, self.transformer.word_embeddings)
 
         # Initialize weights and apply final processing
@@ -192,10 +188,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
 
     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
 
         # Initialize weights and apply final processing
         self.post_init()
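
Since the diff only shows fragments of the prompt-tuning path, the following is a minimal, self-contained sketch (toy sizes, plain tensors, no DHT or remote blocks) of how the deep_ptune shapes fit together after this change. The prefix_tokens, view/permute, and concatenation mirror get_prompt() and forward() above; everything else (the sizes, the random inputs) is illustrative only.

# Standalone sketch of the deep_ptune prompt shapes introduced above.
import torch
import torch.nn as nn

batch_size, pre_seq_len, num_hidden_layers, hidden_size = 2, 16, 4, 64

# One prompt table for the input embeddings...
prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size)
# ...and one for the per-layer intermediate prompts; the diff sizes it as
# num_hidden_layers * hidden_size (with a TODO to use num_hidden_layers - 1).
intermediate_prompt_embeddings = nn.Embedding(pre_seq_len, num_hidden_layers * hidden_size)
intermediate_prompt_embeddings.weight.data.zero_()

prefix_tokens = torch.arange(pre_seq_len).unsqueeze(0).expand(batch_size, -1)
prompts = prompt_embeddings(prefix_tokens)                      # [batch, pre_seq_len, hidden]
intermediate_prompts = intermediate_prompt_embeddings(prefix_tokens).view(
    batch_size, pre_seq_len, num_hidden_layers, hidden_size
).permute([2, 0, 1, 3])                                         # [layers, batch, pre_seq_len, hidden]

# forward() prepends the trainable prefix to the token embeddings,
# then strips the first pre_seq_len positions from the output.
inputs_embeds = torch.randn(batch_size, 10, hidden_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)      # [batch, pre_seq_len + 10, hidden]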