Selaa lähdekoodia

Estimate adapter memory overhead in choose_num_blocks() (#346)

* estimate adapter memory overhead
* reduce number of heads based on that

---------

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic 2 vuotta sitten
vanhempi
commit
010857a834
3 muutettua tiedostoa jossa 45 lisäystä ja 9 poistoa
  1. 13 4
      src/petals/server/server.py
  2. 0 2
      src/petals/utils/convert_block.py
  3. 32 3
      src/petals/utils/peft.py

+ 13 - 4
src/petals/server/server.py

@@ -30,6 +30,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.peft import estimate_adapter_memory_per_block
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
@@ -176,6 +177,8 @@ class Server:
 
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self.cache_dir = cache_dir
+        self.adapters = adapters
 
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
@@ -197,7 +200,6 @@ class Server:
         self.alloc_timeout = alloc_timeout
         if cache_dir is None:
             cache_dir = DEFAULT_CACHE_DIR
-        self.cache_dir = cache_dir
         self.max_disk_space = max_disk_space
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
@@ -219,8 +221,6 @@ class Server:
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
 
-        self.adapters = adapters
-
         self.stop = threading.Event()
 
     def _choose_num_blocks(self) -> int:
@@ -250,7 +250,16 @@ class Server:
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
         autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
 
-        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
+        if adapters:
+            # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
+            from petals.utils.peft import estimate_adapter_memory_per_block
+
+            adapter_memory_per_block = estimate_adapter_memory_per_block(
+                self.block_config, self.torch_dtype, self.adapters, self.cache_dir
+            )
+        total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block
+
+        num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
         num_blocks = min(num_blocks, self.block_config.num_hidden_layers)

+ 0 - 2
src/petals/utils/convert_block.py

@@ -55,8 +55,6 @@ def convert_block(
         shard.to(device)
 
     if adapters:
-        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
-        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
         from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
         create_lora_adapter(block, quant_type=quant_type)

+ 32 - 3
src/petals/utils/peft.py

@@ -1,9 +1,16 @@
+import os
 import re
 import time
-from typing import List, Optional
+from typing import List, Optional, Sequence
+
+os.environ["BITSANDBYTES_NOWELCOME"] = "1"
 
 import bitsandbytes as bnb
+import peft
+import torch
 import torch.nn as nn
+import transformers
+from accelerate import init_empty_weights
 from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
 from peft.tuners import lora
@@ -12,6 +19,8 @@ from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
+from petals.client.ptune import force_non_empty_weights
+from petals.server.block_utils import resolve_block_dtype
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.misc import QuantType
 
@@ -194,15 +203,35 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                             p.requires_grad = False
 
                     if peft_key.endswith(".lora_A.weight"):
-                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
+                        child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
                         is_lora_a_loaded = True
                     elif peft_key.endswith(".lora_A.bias"):
                         raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
                     elif peft_key.endswith(".lora_B.weight"):
-                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
+                        child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
                         is_lora_b_loaded = True
                     elif peft_key.endswith(".lora_B.bias"):
                         raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
 
                 if is_lora_a_loaded and is_lora_b_loaded:
                     logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")
+
+
+def estimate_adapter_memory_per_block(
+    block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs
+) -> int:
+    """Get the number of extra bytes used to store a set of adapters per given block"""
+    with init_empty_weights(include_buffers=True):
+        block = block_config.block_class(block_config)
+        base_block_parameters = sum(p.numel() for p in block.parameters())
+        create_lora_adapter(block, quant_type=QuantType.NONE)
+
+        for adapter in adapters:
+            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs)
+            assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
+            add_adapter_to_block(
+                block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
+            )
+        adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
+    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    return adapter_parameters * bytes_per_parameter