
Override float32 in config to bfloat16 (#431)

Alexander Borzunov 2 years ago
parent
commit
00d48dcbe1
2 files changed with 4 additions and 3 deletions
  1. README.md (+2, -2)
  2. src/petals/server/block_utils.py (+2, -1)

README.md (+2, -2)

@@ -48,7 +48,7 @@ Petals is a community-run system — we rely on people sharing their GPUs. Y
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install git+https://github.com/bigscience-workshop/petals
-python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
+python -m petals.cli.run_server stabilityai/StableBeluga2
 ```
 
 🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
@@ -57,7 +57,7 @@ python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
 
 ```bash
 sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
-    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 --torch_dtype float16
+    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2
 ```
 
 These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.

src/petals/server/block_utils.py (+2, -1)

@@ -11,7 +11,8 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
     if dtype not in ("auto", None):
         return dtype
-    if config.torch_dtype not in ("auto", None):
+    if config.torch_dtype not in ("auto", None, torch.float32):
+        # If config specifies float32, we override it to the default dtype below
         return config.torch_dtype
     return torch.bfloat16
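
For reference, below is a minimal, self-contained sketch of the dtype resolution logic as it reads after this commit. The imports, the demo block, and the use of a bare `PretrainedConfig` are assumptions for illustration only; the branching itself mirrors the lines shown in the diff above.

```python
# Sketch of resolve_block_dtype() after this commit.
# Assumption: the demo below is illustrative; the real function lives in
# src/petals/server/block_utils.py.
from typing import Union

import torch
from transformers import PretrainedConfig


def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]):
    """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
    if dtype not in ("auto", None):
        return dtype
    if config.torch_dtype not in ("auto", None, torch.float32):
        # If config specifies float32, we override it to the default dtype below
        return config.torch_dtype
    return torch.bfloat16


if __name__ == "__main__":
    cfg = PretrainedConfig()

    cfg.torch_dtype = torch.float32
    print(resolve_block_dtype(cfg, "auto"))         # bfloat16: float32 in the config is overridden

    cfg.torch_dtype = torch.float16
    print(resolve_block_dtype(cfg, "auto"))         # float16: any other explicit config dtype is kept

    print(resolve_block_dtype(cfg, torch.float32))  # float32: a dtype passed by the user is never touched
```

This is also why the README commands above drop `--torch_dtype float16`: a float32 entry in the model config no longer leaks through, and the default "auto" now resolves to bfloat16 or to whatever non-float32 dtype the config declares.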