Kaynağa Gözat

rm redundant code

dbaranchuk 3 yıl önce
ebeveyn
işleme
4fb5e680ee
1 değiştirilmiş dosya ile 1 ekleme ve 36 silme
  1. 1 36
      src/bloom/model.py

+ 1 - 36
src/bloom/model.py

@@ -23,6 +23,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
 from transformers.utils import logging
 
 from src.bloom.block import BloomBlock
@@ -35,42 +36,6 @@ _CONFIG_FOR_DOC = "BloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"
 
 
-class BloomPreTrainedModel(PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
-    config_class = BloomConfig
-    base_model_prefix = "transformer"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["BloomBlock"]
-
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
-    def _init_weights(self, module):
-        """Initialize the weights."""
-        if isinstance(module, (nn.Linear)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, BloomModel):
-            module.gradient_checkpointing = value
-
-
 BLOOM_START_DOCSTRING = r"""
 
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the