artek0chumak committed 2 years ago
parent
commit a162e9d48e

+ 1 - 1
src/petals/server/backend.py

@@ -155,7 +155,7 @@ class TransformerBackend(ModuleBackend):
         for p in self.module.parameters():
             p.data = dummy
 
-    def load_adapter_(self, active_adapter: str = '') -> bool:
+    def load_adapter_(self, active_adapter: str = "") -> bool:
         """Try to make a given adapter set active if it was loaded. Return True if loaded, False if no such adapter"""
         adapter_is_loaded = False
         for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
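
For context on the `load_adapter_` hunk above: a minimal sketch of how such an adapter switch could be implemented, assuming peft-style LoRA layers that expose an `active_adapter` attribute and a per-adapter `lora_A` dict (the same `lora.Linear*` classes this commit's `peft.py` checks for). The helper name `set_active_adapter` is hypothetical; the diff does not show the real method body.

```python
# Minimal sketch (not the Petals implementation): switch every peft LoRA layer in a
# module to a given adapter set. An empty string means "run without any adapter".
import torch.nn as nn
from peft.tuners import lora


def set_active_adapter(module: nn.Module, active_adapter: str = "") -> bool:
    """Return True if the adapter is loaded on at least one layer (or if none was requested)."""
    adapter_is_loaded = False
    for layer in module.modules():
        if isinstance(layer, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
            layer.active_adapter = active_adapter  # assumption: peft keeps the active set here
            if active_adapter in layer.lora_A:  # adapter weights were loaded for this layer
                adapter_is_loaded = True
    return adapter_is_loaded or not active_adapter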

+ 25 - 7
src/petals/server/handler.py

@@ -362,7 +362,11 @@ class TransformerConnectionHandler(ConnectionHandler):
             ), f"rpc_forward should have number of points as number or None, got {points}"
 
             hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
             return runtime_pb2.ExpertResponse(
                 tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@@ -385,7 +389,11 @@ class TransformerConnectionHandler(ConnectionHandler):
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
             hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
 
             # Split the serialized_output for streaming and respond to client
@@ -432,7 +440,11 @@ class TransformerConnectionHandler(ConnectionHandler):
             ), f"rpc_backward should have number of points as number or None, got {points}"
 
             grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
+                *flat_tensors,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
 
             return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@@ -453,7 +465,11 @@ class TransformerConnectionHandler(ConnectionHandler):
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
             grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
+                *flat_tensors,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
             # Split the serialized_grad_inputs for streaming and respond
             for tensor in self._serialize_grads(grads, requested_backends, metadata):
@@ -558,7 +574,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 async def _rpc_forward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = '',
+    active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> torch.Tensor:
@@ -590,7 +606,9 @@ async def _rpc_forward(
             hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states, active_adapter, priority=priority,
+            hidden_states,
+            active_adapter,
+            priority=priority,
         )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
@@ -603,7 +621,7 @@ async def _rpc_forward(
 async def _rpc_backward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = '',
+    active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
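
The handler hunks above are pure formatting: each call now spreads its keyword arguments over separate lines. Since `_rpc_forward` and `_rpc_backward` take `*flat_tensors` first, everything after the tensors is keyword-only, which fixes the call shape. A hypothetical caller (placeholder names `backends` and `prio`, not from the diff) looks like this:

```python
# Call-shape illustration: because _rpc_forward declares *flat_tensors first, every
# remaining argument must be passed by keyword, exactly as the reformatted calls do.
async def run_forward(flat_inputs, backends, prio, active_adapter: str = ""):
    return await _rpc_forward(
        *flat_inputs,                   # flat input tensors, consumed positionally
        requested_backends=backends,    # Sequence[TransformerBackend]
        prioritizer=prio,               # TaskPrioritizerBase
        active_adapter=active_adapter,  # empty string = run without a LoRA adapter
        points=0,
    )
```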

+ 9 - 2
src/petals/server/server.py

@@ -218,7 +218,7 @@ class Server:
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
-        
+
         self.adapters = adapters
 
         self.stop = threading.Event()
@@ -421,7 +421,14 @@ class ModuleContainer(threading.Thread):
                     max_disk_space=max_disk_space,
                 )
                 block = convert_block(
-                    block, block_index, block_config, tensor_parallel_devices, device, quant_type, adapters=adapters, freeze=True,
+                    block,
+                    block_index,
+                    block_config,
+                    tensor_parallel_devices,
+                    device,
+                    quant_type,
+                    adapters=adapters,
+                    freeze=True,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
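
The `convert_block` call in `server.py` gets the same multi-line treatment. For readability, here is the signature this call site implies; the parameter names follow the arguments used above, while the type hints, defaults, and return type are assumptions rather than a copy of `convert_block.py`:

```python
# Signature sketch inferred from the call site above; hints and defaults are assumptions.
from typing import Optional, Sequence

import torch
import torch.nn as nn
from transformers import PretrainedConfig


def convert_block(
    block: nn.Module,
    block_index: int,
    config: PretrainedConfig,
    tensor_parallel_devices: Sequence[torch.device],
    output_device: torch.device,
    quant_type,  # QuantType enum in Petals; left untyped in this sketch
    adapters: Optional[Sequence[str]] = None,
    freeze: bool = True,
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    max_disk_space: Optional[int] = None,
) -> nn.Module:
    ...
```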

+ 2 - 1
src/petals/utils/convert_block.py

@@ -10,10 +10,11 @@ import tensor_parallel as tp
 import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from petals.utils.peft import create_lora_adapter, add_adapter_to_block, load_peft
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
+from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 

+ 12 - 10
src/petals/utils/peft.py

@@ -2,9 +2,8 @@ import re
 import time
 from typing import List, Optional
 
-import torch.nn as nn
 import bitsandbytes as bnb
-
+import torch.nn as nn
 from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
 from peft.tuners import lora
@@ -27,7 +26,9 @@ def check_peft_repository(repo_id: str) -> bool:
 def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
     tensors = dict()
     is_tensors_found = dict()
-    common_layer_patter_re = ".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"({block_idx})?\.0\..+"
+    common_layer_patter_re = (
+        ".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"({block_idx})?\.0\..+"
+    )
     with safe_open(filepath, framework=framework, device=device) as f:
         for k in f.keys():
             if re.match(common_layer_patter_re, k):
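
To see what the rewrapped `common_layer_patter_re` selects, here is a small worked example with a hypothetical block index of 4 and an assumed `COMMON_LAYERS_PATTERN` of `["layers", "h"]` (the real list comes from peft):

```python
import re

# Reconstruction of the pattern for block_idx=4; mirrors common_layer_patter_re above.
pattern = r".+\." + "(layers)?(h)?" + r"(4)?\.0\..+"

# A LoRA key for block 4 matches: ".+\." swallows the prefix, "(4)?" takes the block
# index, and "\.0\." pins the per-block submodule right after it.
assert re.match(pattern, "base_model.model.transformer.h.4.0.lora_A.weight")

# A key for another block never has ".0." directly after a matching group, so it is skipped.
assert not re.match(pattern, "base_model.model.transformer.h.7.0.lora_A.weight")
```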
@@ -38,9 +39,7 @@ def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", d
         return tensors
 
 
-def get_adapter_from_repo(
-    repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs
-):
+def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
     config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
     if config_path is None:
         raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
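
Only the config lookup of `get_adapter_from_repo` is visible in this hunk. Based on that alone, a hedged usage sketch (the repo id is a placeholder, and the return value is not relied upon):

```python
# Usage sketch based only on what this hunk shows: the function resolves CONFIG_NAME
# in the repo and raises RuntimeError if it is missing.
from petals.utils.peft import get_adapter_from_repo

try:
    adapter = get_adapter_from_repo("someuser/some-lora-adapter", block_idx=4, device=0)
except RuntimeError as err:  # raised when the adapter config is absent from the repo
    print(f"not a usable PEFT repo: {err}")
```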
@@ -158,8 +157,8 @@ def create_lora_adapter(block):
                 for p in lora_wrapped_child.parameters():
                     p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)
-                
-                
+
+
 def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
     assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
     for name, module in block.named_modules():
@@ -167,7 +166,10 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
             if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
                 continue
 
-            if child_name in peft_config["target_modules"] or (isinstance(peft_config["target_modules"], str) and re.fullmatch(peft_config["target_modules"], child_name)):
+            if child_name in peft_config["target_modules"] or (
+                isinstance(peft_config["target_modules"], str)
+                and re.fullmatch(peft_config["target_modules"], child_name)
+            ):
                 is_lora_a_loaded = False
                 is_lora_b_loaded = False
                 for peft_key in peft_state_dict:
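
The rewrapped condition accepts `target_modules` either as a list of module names or as a single regex string, both of which occur in peft adapter configs. A small standalone check with example BLOOM-style module names (the `is_target` helper is hypothetical and simply mirrors the condition above):

```python
import re


def is_target(child_name, target_modules):
    # Hypothetical helper mirroring the rewrapped condition in add_adapter_to_block.
    return child_name in target_modules or (
        isinstance(target_modules, str) and re.fullmatch(target_modules, child_name)
    )


assert is_target("query_key_value", ["query_key_value", "dense"])   # list form
assert is_target("dense_4h_to_h", r".*dense.*")                      # regex-string form
assert not is_target("word_embeddings", ["query_key_value", "dense"])
```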
@@ -188,6 +190,6 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                     elif "lora_B" in peft_key:
                         child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
                         is_lora_b_loaded = True
-                        
+
                 if is_lora_a_loaded and is_lora_b_loaded:
                     logger.info(f"Loading {adapter_name} for block {block_index} is ended successfully")