Selaa lähdekoodia

Support saving and loading 8-bit block weights

Max Ryabinin 2 vuotta sitten
vanhempi
commit
b624f9048e

+ 7 - 1
src/petals/bloom/from_pretrained.py

@@ -23,11 +23,12 @@ from transformers.utils import get_file_from_repo
 from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import get_block_size
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.convert_block import replace_8bit_linear
 
 logger = get_logger(__name__)
 
 CLIENT_BRANCH = "main"
-BLOCK_BRANCH_PREFIX = "block_"
+BLOCK_BRANCH_PREFIX = "int8_block"
 
 
 def load_pretrained_block(
@@ -38,6 +39,8 @@ def load_pretrained_block(
     use_auth_token: Optional[str] = None,
     cache_dir: Optional[str] = None,
     max_disk_space: Optional[int] = None,
+    load_in_8bit=False,
+    device: Optional[Union[str, torch.device]] = None,
 ) -> WrappedBloomBlock:
     """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
     assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
@@ -49,6 +52,9 @@ def load_pretrained_block(
 
     with init_empty_weights():
         block = WrappedBloomBlock(config)
+    if load_in_8bit:
+        block = replace_8bit_linear(block)
+        block = block.to(device)
 
     state_dict = _load_state_dict(
         converted_model_name_or_path,

+ 13 - 6
src/petals/cli/convert_model.py

@@ -15,16 +15,17 @@ from petals.client import DistributedBloomConfig
 
 logger = get_logger(__name__)
 
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
-
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, int8=torch.int8, auto="auto")
 
 def main():
     parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
 
     parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
     parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
-    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
-    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
+    parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
+                        help="Load initial model in this dtype")
+    parser.add_argument("--output_path", type=str, default="./converted_model",
+                        help="Track output repo to this folder")
     parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
     parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
     parser.add_argument(
@@ -41,7 +42,6 @@ def main():
     if args.model == "bigscience/bloom" and free_ram_gb < 400:
         logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
 
-    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
     if os.path.exists(args.output_path) and (
         len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
     ):
@@ -54,8 +54,15 @@ def main():
     config.dht_prefix = args.output_repo
 
     model = BloomModel.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision,
+        torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
+        load_in_8bit=args.torch_dtype == "int8",
+        device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"}
     )
+    if args.torch_dtype == "int8":
+        # trigger weight quantization
+        model = model.cuda()
+
     if args.resize_token_embeddings:
         logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
         model.resize_token_embeddings(args.resize_token_embeddings)

+ 3 - 1
src/petals/server/server.py

@@ -401,8 +401,10 @@ class ModuleContainer(threading.Thread):
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
+                    load_in_8bit=load_in_8bit,
+                    device=device,
                 )
-                block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
+                block = convert_block(block, block_config, tensor_parallel_devices, device, freeze=True)
 
                 backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(

+ 1 - 8
src/petals/utils/convert_block.py

@@ -24,12 +24,10 @@ def convert_block(
     config: BloomConfig,
     tensor_parallel_devices: Sequence[torch.device],
     output_device: torch.device,
-    load_in_8bit: bool,
-    threshold: float = 6.0,
     freeze: bool = True,
 ) -> tp.TensorParallel:
     """
-    Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
+    Optimize a transformer block for use in a Petals server and apply tensor parallelism
 
     :note: some optimizations will modify the input block in-place!
     :param block: a single transformer block, either pre-trained or newly initialized
@@ -37,8 +35,6 @@ def convert_block(
     :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
     :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity)
     :param output_device: if tensor_parallel_devices is True, output
-    :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
-    :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
     :param freeze: if True (default), make all module parameters non-trainable
     :return: a module that acts like the original block, but runs with all specified optimizations
 
@@ -49,9 +45,6 @@ def convert_block(
 
     block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
 
-    if load_in_8bit:
-        block = replace_8bit_linear(block, threshold=threshold)
-
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)