Browse Source

Remove load_in_8bit from convert_block

Max Ryabinin cách đây 2 năm
mục cha
commit
a610f4d744
1 tập tin đã thay đổi với 4 bổ sung2 xóa
  1. 4 2
      src/petals/server/throughput.py

+ 4 - 2
src/petals/server/throughput.py

@@ -13,7 +13,7 @@ from transformers import BloomConfig
 
 from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_block import convert_block
+from petals.utils.convert_block import convert_block, replace_8bit_linear
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__name__)
@@ -149,7 +149,9 @@ def measure_compute_rps(
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
         block = WrappedBloomBlock(config).to(dtype)
-        block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
+        if load_in_8bit:
+            block = replace_8bit_linear(block)
+        block = convert_block(block, config, tensor_parallel_devices, device, freeze=True)
 
         cache = None
         elapsed = 0