Browse Source

Fix bugs in _choose_num_blocks() added in #346 (#354)

Alexander Borzunov 2 năm trước cách đây
mục cha
commit
9703358df0
2 tập tin đã thay đổi với 19 bổ sung11 xóa
  1. 14 9
      src/petals/server/server.py
  2. 5 2
      src/petals/utils/peft.py

+ 14 - 9
src/petals/server/server.py

@@ -174,9 +174,13 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
+        # For attention cache in GPU or RAM
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+
+        # For disk cache
         self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
         self.adapters = adapters
 
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
@@ -197,9 +201,6 @@ class Server:
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
         self.alloc_timeout = alloc_timeout
-        if cache_dir is None:
-            cache_dir = DEFAULT_CACHE_DIR
-        self.max_disk_space = max_disk_space
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
@@ -243,20 +244,24 @@ class Server:
         else:
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
-
         gib = 1024**3
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
         autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
 
-        if adapters:
+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
+        total_memory_per_block = block_size + self._cache_bytes_per_block
+        if self.adapters:
             # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
             from petals.utils.peft import estimate_adapter_memory_per_block
 
-            adapter_memory_per_block = estimate_adapter_memory_per_block(
-                self.block_config, self.torch_dtype, self.adapters, self.cache_dir
+            total_memory_per_block += estimate_adapter_memory_per_block(
+                self.block_config,
+                self.torch_dtype,
+                self.adapters,
+                use_auth_token=self.use_auth_token,
+                cache_dir=self.cache_dir,
+                max_disk_space=self.max_disk_space,
             )
-        total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block
 
         num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"

+ 5 - 2
src/petals/utils/peft.py

@@ -217,7 +217,10 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
 
 
 def estimate_adapter_memory_per_block(
-    block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs
+    block_config: transformers.PretrainedConfig,
+    torch_dtype: Optional[torch.dtype],
+    adapters: Sequence[str],
+    **load_peft_kwargs,
 ) -> int:
     """Get the number of extra bytes used to store a set of adapters per given block"""
     with init_empty_weights(include_buffers=True):
@@ -226,7 +229,7 @@ def estimate_adapter_memory_per_block(
         create_lora_adapter(block, quant_type=QuantType.NONE)
 
         for adapter in adapters:
-            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs)
+            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
             assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
             add_adapter_to_block(
                 block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict