@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -108,11 +108,24 @@ BLOOM_INPUTS_DOCSTRING = r"""
 """


+class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
+    @classmethod
+    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    )
+
+
 @add_start_docstrings(
     "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
     BLOOM_START_DOCSTRING,
 )
-class BloomModel(BloomPreTrainedModel):
+class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
     def __init__(self, config):
         super().__init__(config)
         assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
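For illustration (not part of the patch): with the override above, any subclass inherits `low_cpu_mem_usage=True` as the effective default while still allowing an explicit opt-out. A minimal sketch, assuming the `BloomForCausalLM` class defined later in this file; the checkpoint name is a placeholder:

```python
# Loads with low_cpu_mem_usage=True because the override fills in the default.
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

# Callers can still opt out explicitly; the override only applies when the
# argument is omitted (i.e., left as None).
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", low_cpu_mem_usage=False)
```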
@@ -277,7 +290,7 @@ class BloomModel(BloomPreTrainedModel):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

     def __init__(self, config):
@@ -400,8 +413,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 @add_start_docstrings(
     """
     The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
+    In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
""",
|
|
|
BLOOM_START_DOCSTRING,
|
|
|
)
|
|
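As background for the docstring above (illustration only, not the patch's actual `LMHead`): a weight-tied head can compute logits directly from the shared embedding matrix via `F.linear`, so no separate `lm_head.weight` tensor is ever allocated. A minimal sketch under that assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLMHeadSketch(nn.Module):
    """Hypothetical stand-in: reuses the input embedding matrix as the output projection."""

    def __init__(self, word_embeddings: nn.Embedding):
        super().__init__()
        self.word_embeddings = word_embeddings  # shared reference, not a copy

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # logits = hidden_states @ E^T, computed without materializing a second
        # vocab-sized weight tensor
        return F.linear(hidden_states, self.word_embeddings.weight)


emb = nn.Embedding(250880, 1024)        # BLOOM-sized vocabulary (for illustration)
head = TiedLMHeadSketch(emb)
logits = head(torch.randn(1, 8, 1024))  # shape: (1, 8, 250880)
```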
@@ -470,7 +483,7 @@ class LMHead(nn.Module):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForSequenceClassification(BloomPreTrainedModel):
+class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

     def __init__(self, config):