
Fix deps, enable 8-bit by default for TP (#298)

This PR fixes the issues from #290:

- The hivemind bfloat16 codec crashed on dummy tensors (tensors with 0 elements); see https://github.com/learning-at-home/hivemind/pull/560. As a temporary measure, this PR makes Petals depend on the latest hivemind version from the repository (the kind of tensor involved is sketched below this list).
- The transformers version check in `src/petals/bloom/block.py` did not match the version range allowed in `setup.cfg`.
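
For context, here is the kind of tensor the codec choked on. This is an illustrative sketch only: the variable name and the bfloat16 dtype are assumptions for the example, and it does not call hivemind's serialization code.

```python
import torch

# A "dummy" tensor with zero elements, the case that crashed hivemind's
# bfloat16 codec before https://github.com/learning-at-home/hivemind/pull/560.
dummy = torch.empty(0, dtype=torch.bfloat16)  # dtype chosen for illustration

print(dummy.numel())  # 0
print(dummy.shape)    # torch.Size([0])
```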

Also:

- This PR enables 8-bit by default for TP. Even though TP in 8-bit may be slower, we currently prefer to host more blocks to increase the network's stability. (The resulting default is sketched after the `src/petals/server/server.py` diff below.)
Alexander Borzunov, 2 years ago
Commit 2116df08bc

3 files changed, 3 insertions(+), 9 deletions(-)
  1. setup.cfg (+1, -1)
  2. src/petals/bloom/block.py (+2, -2)
  3. src/petals/server/server.py (+0, -6)

setup.cfg (+1, -1)

@@ -37,7 +37,7 @@ install_requires =
     huggingface-hub==0.11.1
     transformers>=4.25.1,<5.0.0
     speedtest-cli==2.1.3
-    hivemind==1.1.6
+    hivemind @ git+https://github.com/learning-at-home/hivemind.git
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2

src/petals/bloom/block.py (+2, -2)

@@ -13,8 +13,8 @@ from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.26.0") < version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.26.0,<5.0.0"
+        version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"
 
 
 class WrappedBloomBlock(BloomBlock):
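
To see why the old check was too strict: it required a version strictly greater than 4.26.0, while `setup.cfg` allows `transformers>=4.25.1,<5.0.0`, so a perfectly valid install of 4.25.1 failed the assert. A quick sketch (assuming `version` here is `packaging.version`, as the `version.parse` calls above suggest):

```python
from packaging import version

installed = version.parse("4.25.1")  # a version that setup.cfg accepts

# Old check: required a version strictly greater than 4.26.0,
# so it rejects 4.25.1 (and 4.26.0 itself).
old_check = version.parse("4.26.0") < installed < version.parse("5.0.0")

# New check: mirrors setup.cfg's transformers>=4.25.1,<5.0.0.
new_check = version.parse("4.25.1") <= installed < version.parse("5.0.0")

print(old_check, new_check)  # False True
```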

src/petals/server/server.py (+0, -6)

@@ -163,12 +163,6 @@ class Server:
 
         if load_in_8bit is None:
             load_in_8bit = device.type == "cuda"
-            if load_in_8bit and len(self.tensor_parallel_devices) > 1:
-                load_in_8bit = False
-                logger.warning(
-                    "Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. "
-                    "You can explicitly set `--load_in_8bit True` to override this"
-                )
         self.load_in_8bit = load_in_8bit
         logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
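
With this guard removed, the default falls through to the remaining line above: 8-bit is enabled whenever the server runs on a CUDA device, regardless of how many tensor-parallel devices are configured. A minimal standalone sketch of that behaviour (the helper function below is hypothetical, not part of the Petals API):

```python
import torch

def default_load_in_8bit(device: torch.device, tensor_parallel_devices: tuple) -> bool:
    # Mirrors the new default: it depends only on the device type and
    # no longer disables 8-bit when several tensor-parallel devices are used.
    return device.type == "cuda"

print(default_load_in_8bit(torch.device("cpu"), ()))                     # False
print(default_load_in_8bit(torch.device("cuda"), ("cuda:0", "cuda:1")))  # True
```

Per the description above, hosting more blocks in 8-bit is currently preferred over the possibly faster 16-bit path for tensor parallelism.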