Bladeren bron

Fix ptune with `low_cpu_mem_usage=True` (as in Colab) (#103)

Fixes:

- An exception while creating a model with `ptune/deep_ptune` and `low_cpu_mem_usage=True` (which is currently default).
- dtype mismatch between the prompts and the rest of the model in `.forward()`.
Alexander Borzunov 2 jaren geleden
bovenliggende
commit
0a1cd3b9ba
1 gewijzigde bestanden met toevoegingen van 55 en 10 verwijderingen
  1. 55 10
      src/petals/client/remote_model.py

+ 55 - 10
src/petals/client/remote_model.py

@@ -1,4 +1,4 @@
-# this code is in active development, interfaces may change
+from contextlib import contextmanager
 from typing import List, Optional
 
 import hivemind
@@ -38,9 +38,35 @@ class DistributedBloomConfig(BloomConfig):
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
 
 
+original_register_parameter = nn.Module.register_parameter
+
+
+@contextmanager
+def force_non_empty_weights():
+    """
+    This context manager allows to bypass the accelerate.init_empty_weights() context manager
+    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
+    The transformers library should replace all meta tensors by empty tensors by itself
+    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
+
+    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
+    """
+
+    try:
+        possibly_patched_register_parameter = nn.Module.register_parameter
+        nn.Module.register_parameter = original_register_parameter
+        yield
+    finally:
+        nn.Module.register_parameter = possibly_patched_register_parameter
+
+
 class DistributedBloomModel(BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
 
+    _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
+        r"^(intermediate_)?prompt_embeddings\.weight$",
+    ]
+
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
@@ -66,16 +92,22 @@ class DistributedBloomModel(BloomModel):
         if config.tuning_mode and "ptune" in config.tuning_mode:
             assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
             self.pre_seq_len = config.pre_seq_len
-            self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
             self.prefix_tokens = torch.arange(self.pre_seq_len).long()
 
-            if config.tuning_mode == "deep_ptune":
-                self.intermediate_prompt_embeddings = nn.Embedding(
-                    self.pre_seq_len,
-                    config.num_hidden_layers * config.hidden_size
-                    # ^-- TODO: should be num_hidden_layers - 1
-                )
-                self.intermediate_prompt_embeddings.weight.data.zero_()
+            with force_non_empty_weights():
+                if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
+                    logger.info(
+                        "Prompt embeddings and their optimizer statistics will be kept in float32 "
+                        "to increase ptune quality"
+                    )
+                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
+                if config.tuning_mode == "deep_ptune":
+                    self.intermediate_prompt_embeddings = nn.Embedding(
+                        self.pre_seq_len,
+                        config.num_hidden_layers * config.hidden_size,
+                        # ^-- TODO: should be num_hidden_layers - 1
+                        dtype=torch.float32,
+                    )
         elif config.tuning_mode:
             raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
 
@@ -96,7 +128,9 @@ class DistributedBloomModel(BloomModel):
             intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
         else:
             intermediate_prompts = DUMMY
-        return prompts, intermediate_prompts
+
+        dtype = self.word_embeddings.weight.dtype
+        return prompts.to(dtype), intermediate_prompts.to(dtype)
 
     def forward(
         self,
@@ -155,6 +189,12 @@ class DistributedBloomModel(BloomModel):
 class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
+    _keys_to_ignore_on_load_missing = (
+        BloomForCausalLM._keys_to_ignore_on_load_missing
+        + DistributedBloomModel._keys_to_ignore_on_load_missing
+        + [r"^lm_head.word_embeddings\.weight$"]  # Missing since they are shared with input embeddings
+    )
+
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
@@ -185,6 +225,11 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
 
 
 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
+    _keys_to_ignore_on_load_missing = (
+        BloomForSequenceClassification._keys_to_ignore_on_load_missing
+        + DistributedBloomModel._keys_to_ignore_on_load_missing
+    )
+
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):