Alexander Borzunov, 2 years ago
commit e12d4c666b

+ 2 - 0
src/petals/__init__.py

@@ -1,5 +1,7 @@
 import os
 
+os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
+
 import hivemind
 import transformers
 from packaging import version
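
The two added lines move the BITSANDBYTES_NOWELCOME flag into the package's top-level __init__.py, before bitsandbytes can be imported anywhere else. Using os.environ.setdefault (rather than the plain assignments removed from convert_block.py and peft.py below) means an explicit user setting wins over the package default. A minimal standalone sketch of that behavior, not part of the commit:

    import os

    # A user's explicit choice survives the package default...
    os.environ["BITSANDBYTES_NOWELCOME"] = "0"
    os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
    assert os.environ["BITSANDBYTES_NOWELCOME"] == "0"

    # ...while an unset variable picks up the default "1".
    del os.environ["BITSANDBYTES_NOWELCOME"]
    os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
    assert os.environ["BITSANDBYTES_NOWELCOME"] == "1"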

+ 8 - 9
src/petals/server/backend.py

@@ -82,14 +82,12 @@ class TransformerBackend(ModuleBackend):
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        if not self.load_adapter_(active_adapter):
-            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
+        self.load_adapter_(active_adapter)
         return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        if not self.load_adapter_(active_adapter):
-            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
+        self.load_adapter_(active_adapter)
         return super().backward(*inputs)
 
     @torch.inference_mode()
@@ -100,8 +98,7 @@ class TransformerBackend(ModuleBackend):
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        if not self.load_adapter_(inference_info.active_adapter):
-            raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
+        self.load_adapter_(inference_info.active_adapter)
         with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
@@ -159,13 +156,15 @@ class TransformerBackend(ModuleBackend):
         # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
         from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
 
-        adapter_was_loaded = False
+        loaded = False
         for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
             if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
                 layer.active_adapter = active_adapter  # empty string for no adapter
                 if active_adapter in layer.lora_A.keys():
-                    adapter_was_loaded = True
-        return adapter_was_loaded or not active_adapter
+                    loaded = True
+
+        if active_adapter and not loaded:
+            raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded")
 
 
 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
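
With this change, load_adapter_ no longer returns a bool for its callers to check; it raises KeyError itself when a non-empty adapter name is not found, so forward, backward and inference_step simply call it. A minimal standalone sketch of the new contract (FakeLoraLinear is a hypothetical stand-in for peft's lora layers, not the actual Petals code):

    from typing import Iterable, List

    class FakeLoraLinear:
        """Hypothetical stand-in for peft.tuners.lora.Linear with a dict of loaded adapters."""

        def __init__(self, adapters: Iterable[str]):
            self.lora_A = {name: None for name in adapters}
            self.active_adapter = ""

    def load_adapter_(layers: List[FakeLoraLinear], active_adapter: str = "") -> None:
        loaded = False
        for layer in layers:
            layer.active_adapter = active_adapter  # empty string selects "no adapter"
            if active_adapter in layer.lora_A:
                loaded = True
        if active_adapter and not loaded:
            raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded")

    layers = [FakeLoraLinear(["my-lora"])]
    load_adapter_(layers, "")         # no adapter requested: always succeeds
    load_adapter_(layers, "my-lora")  # loaded adapter: selected on every layer
    try:
        load_adapter_(layers, "missing")  # unknown adapter: the error now comes from load_adapter_
    except KeyError as exc:
        print(exc)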

+ 2 - 2
src/petals/server/handler.py

@@ -307,10 +307,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Directly push activation tensors from one server to another"""
 
         requested_uids = self._check_uids(request.uid)
-        self._log_request("rpc_push", requested_uids, context)
-
         metadata = MSGPackSerializer.loads(request.metadata)
         session_id = metadata["session_id"]
+        self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}")
+
         self._session_queues[session_id].put(request)
         return runtime_pb2.ExpertResponse()
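
The change in rpc_push is purely about ordering: request.metadata is deserialized before _log_request is called, so the session id can be attached to the debug output. A minimal sketch of the same pattern, using a hypothetical log_request helper and msgpack in place of hivemind's MSGPackSerializer to keep the example self-contained:

    import msgpack  # assumption: stands in for hivemind's MSGPackSerializer here

    def log_request(method: str, uids, debug: str = "") -> None:
        # Hypothetical stand-in for TransformerConnectionHandler._log_request
        print(f"{method}({', '.join(uids)}) {debug}")

    raw_metadata = msgpack.dumps({"session_id": "abc123"})

    metadata = msgpack.loads(raw_metadata)  # parse the metadata first...
    session_id = metadata["session_id"]
    log_request("rpc_push", ["uid0", "uid1"], debug=f"session_id={session_id}")  # ...then log with context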
 

+ 0 - 1
src/petals/utils/convert_block.py

@@ -71,7 +71,6 @@ def convert_block(
 
 def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module:
     # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
-    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
     import bitsandbytes as bnb
 
     for n, module in model.named_children():
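
With the flag now set once in petals/__init__.py, quantize_module no longer needs to touch os.environ; the deferred import inside the function is what keeps Petals importable on platforms bitsandbytes does not support. A generic sketch of that deferred-import pattern (an illustration, not the Petals code):

    def quantize_if_possible(module):
        # Deferred import: the optional dependency is loaded only when quantization
        # is actually requested, so merely importing the package never fails on
        # platforms where bitsandbytes is unavailable.
        try:
            import bitsandbytes as bnb  # noqa: F401
        except ImportError:
            return module  # leave the module unquantized
        ...  # quantization would happen here
        return module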

+ 11 - 12
src/petals/utils/peft.py

@@ -1,12 +1,8 @@
-import os
 import re
 import time
-from typing import List, Optional, Sequence
-
-os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+from typing import Optional, Sequence
 
 import bitsandbytes as bnb
-import peft
 import torch
 import torch.nn as nn
 import transformers
@@ -19,7 +15,6 @@ from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
-from petals.client.ptune import force_non_empty_weights
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.misc import QuantType
@@ -124,7 +119,7 @@ def load_peft(
 
 
 def create_lora_adapter(block, quant_type: QuantType):
-    for name, module in block.named_modules():
+    for _, module in block.named_modules():
         for child_name, child in module.named_children():
             lora_wrapped_child = None
             if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
@@ -173,7 +168,10 @@ def create_lora_adapter(block, quant_type: QuantType):
 
 def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
     assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
-    for name, module in block.named_modules():
+    if peft_config["lora_dropout"] > 0:
+        logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout")
+
+    for _, module in block.named_modules():
         for child_name, child in module.named_children():
             if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
                 continue
@@ -185,7 +183,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                 is_lora_a_loaded = False
                 is_lora_b_loaded = False
                 for peft_key in peft_state_dict:
-                    if peft_key.find(child_name) == -1:
+                    if child_name not in peft_key:
                         continue
 
                     if adapter_name not in child.lora_A:
@@ -197,8 +195,6 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                             init_lora_weights=peft_config["init_lora_weights"],
                         )
                         child.train(False)
-                        if peft_config["lora_dropout"] > 0:
-                            logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
                         for p in child.parameters():
                             p.requires_grad = False
 
@@ -214,7 +210,10 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                         raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
 
                 if is_lora_a_loaded and is_lora_b_loaded:
-                    logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")
+                    logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}")
+                elif is_lora_a_loaded or is_lora_b_loaded:
+                    raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}")
+    logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
 
 
 def estimate_adapter_memory_per_block(
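
Two behavioral changes land in add_adapter_to_block: the per-module success message drops to debug level with a single info summary per block, and a half-loaded adapter (only lora_A or only lora_B weights found for a wrapped child) now raises ValueError instead of passing silently. A toy sketch of that validation rule, with hypothetical names and a fake state dict rather than the actual Petals code:

    def check_adapter_keys(child_name: str, peft_state_dict: dict) -> None:
        is_lora_a_loaded = any(child_name in key and "lora_A" in key for key in peft_state_dict)
        is_lora_b_loaded = any(child_name in key and "lora_B" in key for key in peft_state_dict)
        if is_lora_a_loaded and is_lora_b_loaded:
            return  # fully loaded: the real code logs this at debug level
        if is_lora_a_loaded or is_lora_b_loaded:
            raise ValueError(f"Invalid adapter for {child_name}: only half of the LoRA weights were found")

    state_dict = {
        "layers.0.self_attn.q_proj.lora_A.weight": ...,
        "layers.0.self_attn.q_proj.lora_B.weight": ...,
        "layers.0.self_attn.k_proj.lora_A.weight": ...,  # lora_B counterpart is missing
    }
    check_adapter_keys("q_proj", state_dict)  # passes
    check_adapter_keys("k_proj", state_dict)  # raises ValueError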