
Set low_cpu_mem_usage=True by default

Aleksandr Borzunov, 2 years ago
parent commit 6e7565e41e
1 changed file with 19 additions and 6 deletions

src/bloom/model.py (+19, -6)

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -108,11 +108,24 @@ BLOOM_INPUTS_DOCSTRING = r"""
 """
 """
 
 
 
 
+class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
+    @classmethod
+    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    )
+
+
 @add_start_docstrings(
     "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
     BLOOM_START_DOCSTRING,
 )
-class BloomModel(BloomPreTrainedModel):
+class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
     def __init__(self, config):
         super().__init__(config)
         assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
@@ -277,7 +290,7 @@ class BloomModel(BloomPreTrainedModel):
     """,
     """,
     BLOOM_START_DOCSTRING,
     BLOOM_START_DOCSTRING,
 )
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

     def __init__(self, config):
@@ -400,8 +413,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 @add_start_docstrings(
     """
     The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. 
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.  
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
+    In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
     """,
     """,
     BLOOM_START_DOCSTRING,
     BLOOM_START_DOCSTRING,
 )
 )
@@ -470,7 +483,7 @@ class LMHead(nn.Module):
     """,
     """,
     BLOOM_START_DOCSTRING,
     BLOOM_START_DOCSTRING,
 )
 )
-class BloomForSequenceClassification(BloomPreTrainedModel):
+class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

     def __init__(self, config):
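
Note on usage: with this override, every model class inheriting from _BloomPreTrainedModelWithModifiedDefaults picks up low_cpu_mem_usage=True unless the caller sets it explicitly, since the overridden from_pretrained fills in the default before delegating to the transformers implementation. A minimal sketch of the resulting behavior (the checkpoint name and import path below are illustrative, not part of this commit):

    # Illustrative import path; use whatever path src/bloom/model.py is importable under.
    from src.bloom.model import BloomForCausalLM

    # Implicitly loads with low_cpu_mem_usage=True via the modified default.
    model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

    # The stock transformers default can still be requested explicitly;
    # an explicit argument is passed through unchanged.
    model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", low_cpu_mem_usage=False)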