|
@@ -1,4 +1,4 @@
|
|
|
-# this code is in active development, interfaces may change
|
|
|
+from contextlib import contextmanager
|
|
|
from typing import List, Optional
|
|
|
|
|
|
import hivemind
|
|
@@ -38,9 +38,35 @@ class DistributedBloomConfig(BloomConfig):
|
|
|
tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
|
|
|
|
|
|
|
|
|
+original_register_parameter = nn.Module.register_parameter
|
|
|
+
|
|
|
+
|
|
|
+@contextmanager
|
|
|
+def force_non_empty_weights():
|
|
|
+ """
|
|
|
+ This context manager allows to bypass the accelerate.init_empty_weights() context manager
|
|
|
+ (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
|
|
|
+ The transformers library should replace all meta tensors by empty tensors by itself
|
|
|
+ but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
|
|
|
+
|
|
|
+ [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
|
|
|
+ """
|
|
|
+
|
|
|
+ try:
|
|
|
+ possibly_patched_register_parameter = nn.Module.register_parameter
|
|
|
+ nn.Module.register_parameter = original_register_parameter
|
|
|
+ yield
|
|
|
+ finally:
|
|
|
+ nn.Module.register_parameter = possibly_patched_register_parameter
|
|
|
+
|
|
|
+
|
|
|
class DistributedBloomModel(BloomModel):
|
|
|
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
|
|
|
|
|
+ _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
|
|
|
+ r"^(intermediate_)?prompt_embeddings\.weight$",
|
|
|
+ ]
|
|
|
+
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
@@ -66,16 +92,22 @@ class DistributedBloomModel(BloomModel):
|
|
|
if config.tuning_mode and "ptune" in config.tuning_mode:
|
|
|
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
|
|
self.pre_seq_len = config.pre_seq_len
|
|
|
- self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
|
|
|
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
|
|
|
|
|
- if config.tuning_mode == "deep_ptune":
|
|
|
- self.intermediate_prompt_embeddings = nn.Embedding(
|
|
|
- self.pre_seq_len,
|
|
|
- config.num_hidden_layers * config.hidden_size
|
|
|
- # ^-- TODO: should be num_hidden_layers - 1
|
|
|
- )
|
|
|
- self.intermediate_prompt_embeddings.weight.data.zero_()
|
|
|
+ with force_non_empty_weights():
|
|
|
+ if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
|
|
|
+ logger.info(
|
|
|
+ "Prompt embeddings and their optimizer statistics will be kept in float32 "
|
|
|
+ "to increase ptune quality"
|
|
|
+ )
|
|
|
+ self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
|
|
|
+ if config.tuning_mode == "deep_ptune":
|
|
|
+ self.intermediate_prompt_embeddings = nn.Embedding(
|
|
|
+ self.pre_seq_len,
|
|
|
+ config.num_hidden_layers * config.hidden_size,
|
|
|
+ # ^-- TODO: should be num_hidden_layers - 1
|
|
|
+ dtype=torch.float32,
|
|
|
+ )
|
|
|
elif config.tuning_mode:
|
|
|
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
|
|
|
|
@@ -96,7 +128,9 @@ class DistributedBloomModel(BloomModel):
|
|
|
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
|
|
|
else:
|
|
|
intermediate_prompts = DUMMY
|
|
|
- return prompts, intermediate_prompts
|
|
|
+
|
|
|
+ dtype = self.word_embeddings.weight.dtype
|
|
|
+ return prompts.to(dtype), intermediate_prompts.to(dtype)
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
@@ -155,6 +189,12 @@ class DistributedBloomModel(BloomModel):
|
|
|
class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
|
|
|
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
|
|
|
|
|
+ _keys_to_ignore_on_load_missing = (
|
|
|
+ BloomForCausalLM._keys_to_ignore_on_load_missing
|
|
|
+ + DistributedBloomModel._keys_to_ignore_on_load_missing
|
|
|
+ + [r"^lm_head.word_embeddings\.weight$"] # Missing since they are shared with input embeddings
|
|
|
+ )
|
|
|
+
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
@@ -185,6 +225,11 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
|
|
|
|
|
|
|
|
|
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
|
|
|
+ _keys_to_ignore_on_load_missing = (
|
|
|
+ BloomForSequenceClassification._keys_to_ignore_on_load_missing
|
|
|
+ + DistributedBloomModel._keys_to_ignore_on_load_missing
|
|
|
+ )
|
|
|
+
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|