Fix OOMs happening with accelerate >= 0.16.0 (#310)

- After #285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, it saves the tensor in the dtype previously used by the model instead of the dtype of the loaded weights (https://github.com/huggingface/accelerate/pull/920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`)
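The `torch_dtype="auto"` behavior described above can be sketched as a small helper. This is an illustrative sketch, not Petals' actual implementation; the function name `resolve_torch_dtype` and the string-based dtypes are assumptions for clarity:

```python
def resolve_torch_dtype(torch_dtype, config_dtype):
    """Pick the dtype to load block weights in.

    torch_dtype  -- "auto" or an explicit dtype such as "float16"
    config_dtype -- the `torch_dtype` field read from the model's config.json
    """
    if torch_dtype == "auto":
        # Default: trust the dtype the checkpoint was published with
        return config_dtype
    # An explicit user-provided dtype overrides config.json
    return torch_dtype

# "auto" falls back to config.json; an explicit dtype wins over it
print(resolve_torch_dtype("auto", "bfloat16"))     # -> bfloat16
print(resolve_torch_dtype("float16", "bfloat16"))  # -> float16
```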
Alexander Borzunov · 2 years ago · commit 454c193863

2 files changed, 2 insertions(+), 2 deletions(-):

1. setup.cfg (+1 −1)
2. src/petals/bloom/from_pretrained.py (+1 −1)

setup.cfg (+1 −1)

@@ -33,7 +33,7 @@ python_requires = >=3.7
 install_requires =
     torch>=1.12
     bitsandbytes==0.38.0.post2
-    accelerate>=0.15.0,<1.0.0
+    accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     transformers>=4.25.1,<5.0.0
     speedtest-cli==2.1.3

src/petals/bloom/from_pretrained.py (+1 −1)

@@ -68,7 +68,7 @@ def load_pretrained_block(
         param = state_dict[param_name]
         if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
             param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param)
+        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
 
     logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
     return block
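For context on why the silent float32 fallback caused OOMs: each parameter doubles in size when promoted from float16/bfloat16 to float32. A rough back-of-the-envelope calculation (the 176B parameter count matches BLOOM-176B; the arithmetic ignores activations and attention caches, which made the overshoot even worse):

```python
# Bytes per element for common parameter dtypes
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}

def weights_gib(n_params: int, dtype: str) -> float:
    """Approximate weight memory in GiB for n_params parameters of `dtype`."""
    return n_params * BYTES_PER_PARAM[dtype] / 2**30

n = 176_000_000_000  # BLOOM-176B parameter count
print(f"float16: {weights_gib(n, 'float16'):.0f} GiB")  # ~328 GiB
print(f"float32: {weights_gib(n, 'float32'):.0f} GiB")  # ~656 GiB, twice the budget
```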