Ver Fonte

Merge remote-tracking branch 'origin/main' into optimize_seq

justheuristic há 3 anos atrás
pai
commit
a5fdb5753e
4 ficheiros alterados com 13 adições e 45 exclusões
  1. 1 36
      src/bloom/model.py
  2. 4 1
      src/client/remote_model.py
  3. 0 2
      src/server/backend.py
  4. 8 6
      src/utils/convert_8bit.py

+ 1 - 36
src/bloom/model.py

@@ -23,6 +23,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
 from transformers.utils import logging
 
 from src.bloom.block import BloomBlock
@@ -35,42 +36,6 @@ _CONFIG_FOR_DOC = "BloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"
 
 
-class BloomPreTrainedModel(PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
-    config_class = BloomConfig
-    base_model_prefix = "transformer"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["BloomBlock"]
-
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
-    def _init_weights(self, module):
-        """Initialize the weights."""
-        if isinstance(module, (nn.Linear)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, BloomModel):
-            module.gradient_checkpointing = value
-
-
 BLOOM_START_DOCSTRING = r"""
 
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

+ 4 - 1
src/client/remote_model.py

@@ -187,8 +187,11 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
-        super().__init__(config)
+        BloomPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
         self.transformer = DistributedBloomModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()

+ 0 - 2
src/server/backend.py

@@ -81,8 +81,6 @@ class TransformerBackend(ModuleBackend):
                 assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
                 assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
                 assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
-                assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
-                assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
                 cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
                 cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
                 return (hidden_states,)

+ 8 - 6
src/utils/convert_8bit.py

@@ -4,14 +4,13 @@ import torch
 
 def replace_8bit_linear(model, threshold=6.0):
     """
-    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
+    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
     library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
     8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
     version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
     bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
-    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
-    CPU/GPU memory is required to run this function.
+    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
+    be kept as a `torch.nn.Linear` module.
     Parameters:
         model (`torch.nn.Module`):
             Input model or `torch.nn.Module` as the function is run recursively.
@@ -23,12 +22,15 @@ def replace_8bit_linear(model, threshold=6.0):
         if len(list(module.children())) > 0:
             replace_8bit_linear(module, threshold)
 
-        if isinstance(module, torch.nn.Linear) and n != "lm_head":
+        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
             model._modules[n] = bnb.nn.Linear8bitLt(
                 module.in_features,
                 module.out_features,
                 module.bias is not None,
                 has_fp16_weights=False,
                 threshold=threshold,
-            ).to(model._modules[n].weight.device)
+            )
+            model._modules[n].weight = bnb.nn.Int8Params(
+                module.weight.data, requires_grad=False, has_fp16_weights=False
+            ).to(module.weight.dtype)
     return model