浏览代码

Force transformers to use config.torch_dtype by default (#307)

Alexander Borzunov 2 年之前
父节点
当前提交
c0e0e1319d
共有 1 个文件被更改,包括 19 次插入6 次删除
  1. 19 6
      src/petals/client/remote_model.py

+ 19 - 6
src/petals/client/remote_model.py

@@ -71,20 +71,33 @@ def force_non_empty_weights():
         nn.Module.register_parameter = possibly_patched_register_parameter
 
 
-class _LowCPUMemoryMixin:
+class _FromPretrainedDefaultsMixin:
     @classmethod
-    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+    def from_pretrained(
+        cls,
+        *args,
+        low_cpu_mem_usage: Optional[bool] = None,
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        **kwargs,
+    ):
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
-        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+        if torch_dtype is None:
+            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
+            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
+            torch_dtype = "auto"
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)
 
     from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
         "low_cpu_mem_usage(`bool`, *optional*)",
         "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    ).replace(
+        "torch_dtype (`str` or `torch.dtype`, *optional*)",
+        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
     )
 
 
-class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
+class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
@@ -218,7 +231,7 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
         )
 
 
-class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
+class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = (
@@ -256,7 +269,7 @@ class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, Blo
             self.lm_head.bias[...] = new_lm_head.bias
 
 
-class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
+class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
     _keys_to_ignore_on_load_missing = (
         BloomForSequenceClassification._keys_to_ignore_on_load_missing
         + DistributedBloomModel._keys_to_ignore_on_load_missing