|
@@ -23,6 +23,7 @@ from transformers.modeling_outputs import (
|
|
|
)
|
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
|
from transformers.models.bloom.configuration_bloom import BloomConfig
|
|
|
+from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
|
|
|
from transformers.utils import logging
|
|
|
|
|
|
from src.bloom.block import BloomBlock
|
|
@@ -35,42 +36,6 @@ _CONFIG_FOR_DOC = "BloomConfig"
|
|
|
_TOKENIZER_FOR_DOC = "BloomTokenizer"
|
|
|
|
|
|
|
|
|
-class BloomPreTrainedModel(PreTrainedModel):
|
|
|
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
|
|
|
- """
|
|
|
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
|
|
- models.
|
|
|
- """
|
|
|
-
|
|
|
- config_class = BloomConfig
|
|
|
- base_model_prefix = "transformer"
|
|
|
- supports_gradient_checkpointing = True
|
|
|
- _no_split_modules = ["BloomBlock"]
|
|
|
-
|
|
|
- def __init__(self, *inputs, **kwargs):
|
|
|
- super().__init__(*inputs, **kwargs)
|
|
|
-
|
|
|
- def _init_weights(self, module):
|
|
|
- """Initialize the weights."""
|
|
|
- if isinstance(module, (nn.Linear)):
|
|
|
- # Slightly different from the TF version which uses truncated_normal for initialization
|
|
|
- # cf https://github.com/pytorch/pytorch/pull/5617
|
|
|
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
|
|
- if module.bias is not None:
|
|
|
- module.bias.data.zero_()
|
|
|
- elif isinstance(module, nn.Embedding):
|
|
|
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
|
|
- if module.padding_idx is not None:
|
|
|
- module.weight.data[module.padding_idx].zero_()
|
|
|
- elif isinstance(module, LayerNorm):
|
|
|
- module.bias.data.zero_()
|
|
|
- module.weight.data.fill_(1.0)
|
|
|
-
|
|
|
- def _set_gradient_checkpointing(self, module, value=False):
|
|
|
- if isinstance(module, BloomModel):
|
|
|
- module.gradient_checkpointing = value
|
|
|
-
|
|
|
-
|
|
|
BLOOM_START_DOCSTRING = r"""
|
|
|
|
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|