Set low_cpu_mem_usage=True by default

Aleksandr Borzunov, 2 years ago
parent
commit 6e7565e41e
1 changed file with 19 additions and 6 deletions

src/bloom/model.py (+19 -6)

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -108,11 +108,24 @@ BLOOM_INPUTS_DOCSTRING = r"""
 """
 
 
+class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
+    @classmethod
+    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    )
+
+
 @add_start_docstrings(
     "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
     BLOOM_START_DOCSTRING,
 )
-class BloomModel(BloomPreTrainedModel):
+class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
     def __init__(self, config):
         super().__init__(config)
         assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
@@ -277,7 +290,7 @@ class BloomModel(BloomPreTrainedModel):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
@@ -400,8 +413,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 @add_start_docstrings(
     """
     The modified language modeling head which does not create an extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. 
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.  
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
+    In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
     """,
     BLOOM_START_DOCSTRING,
 )
@@ -470,7 +483,7 @@ class LMHead(nn.Module):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForSequenceClassification(BloomPreTrainedModel):
+class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
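
For reference, here is how the modified default behaves from a caller's point of view. This is a minimal usage sketch, not part of the commit: it assumes the repo's src/bloom/model.py is importable as shown, and "bigscience/bloom-560m" is used purely as an illustrative checkpoint name.

    # Minimal usage sketch (assumptions noted above).
    from src.bloom.model import BloomForCausalLM

    # After this commit, low_cpu_mem_usage defaults to True: transformers
    # materializes parameters directly from the checkpoint instead of
    # randomly initializing them first, which lowers peak host RAM.
    model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

    # The stock transformers default can still be requested explicitly;
    # the override only fills in a value when the caller passes None.
    model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", low_cpu_mem_usage=False)

Subclassing with a patched __doc__, rather than monkey-patching BloomPreTrainedModel itself, keeps the changed default local to the model classes above and keeps the rendered docstring accurate.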