
Switch adapters slightly faster (#353)

Currently, every call to `TransformerBackend.inference_step` iterates over the block's modules to find the LoRA layers and set the requested adapter on each of them. This is not very expensive, but it can measurably affect inference time.

This pull request switches adapters with a single variable assignment (a context manager that sets a shared class attribute), without iterating over `block.modules()`.
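
For context, the pattern boils down to a class-level context variable that every LoRA layer reads through a property, so activating an adapter costs one assignment per request instead of one per layer. A minimal, self-contained sketch of that idea (simplified, with hypothetical names, not the actual petals.utils.peft code):

    import contextlib

    class AdapterContext:
        """Shared adapter state; all LoRA layers read the same class attribute."""

        ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
        _active = ADAPTER_NOT_SET

        @staticmethod
        @contextlib.contextmanager
        def using_adapter(name):
            # one assignment switches the adapter for every layer at once
            prev, AdapterContext._active = AdapterContext._active, name
            try:
                yield
            finally:
                AdapterContext._active = prev

    class FakeLoraLayer(AdapterContext):
        @property
        def active_adapter(self):
            # layers no longer store their own copy; they read the shared context
            return self._active

    layer = FakeLoraLayer()
    with AdapterContext.using_adapter("my-lora"):
        assert layer.active_adapter == "my-lora"
    assert layer.active_adapter == AdapterContext.ADAPTER_NOT_SET

The actual change implements this via AdapterContextMixin and subclasses of peft's lora.Linear variants, as shown in the src/petals/utils/peft.py hunks below.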
justheuristic 2 years ago
parent
commit 37fdcb3fe0

+ 13 - 22
src/petals/server/backend.py

@@ -24,9 +24,15 @@ logger = get_logger(__name__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for a transformer block that can process requests for forward, backward and inference"""
 
+    _peft_module = None
+
     def __init__(
         self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
     ):
+        import petals.utils.peft as _peft_module
+
+        self._peft_module = _peft_module
+
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
@@ -82,13 +88,13 @@ class TransformerBackend(ModuleBackend):
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().forward(*inputs)
+        with self._peft_module.using_adapter(active_adapter):
+            return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().backward(*inputs)
+        with self._peft_module.using_adapter(active_adapter):
+            return super().backward(*inputs)
 
     @torch.inference_mode()
     def inference_step(
@@ -98,8 +104,9 @@ class TransformerBackend(ModuleBackend):
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        self.load_adapter_(inference_info.active_adapter)
-        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+        with self.memory_cache.use_cache(
+            *inference_info.cache_handles
+        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
             hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
@@ -150,22 +157,6 @@ class TransformerBackend(ModuleBackend):
         for p in self.module.parameters():
             p.data = dummy
 
-    def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
-        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
-
-        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
-        from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
-
-        loaded = False
-        for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
-            if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
-                layer.active_adapter = active_adapter  # empty string for no adapter
-                if active_adapter in layer.lora_A.keys():
-                    loaded = True
-
-        if active_adapter and not loaded:
-            raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded")
-
 
 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""

+ 13 - 5
src/petals/server/handler.py

@@ -68,6 +68,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        adapters: Optional[Sequence[str]],
         dht_prefix: str,
         push_manager: multiprocessing.managers.SyncManager,
         session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
@@ -81,6 +82,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.dht_prefix = dht_prefix
+        self.adapters = adapters
         self._push_manager = push_manager
         self._session_queues = session_queues
         self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
@@ -141,7 +143,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
-                active_adapter = metadata.get("active_adapter", "")
+                active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
 
@@ -355,7 +357,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -382,7 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_forward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -433,7 +435,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -458,7 +460,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_backward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -476,6 +478,12 @@ class TransformerConnectionHandler(ConnectionHandler):
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                     yield runtime_pb2.ExpertResponse(tensors=[part])
 
+    def _get_active_adapter(self, metadata: dict) -> str:
+        active_adapter = metadata.get("active_adapter", "")
+        if active_adapter and (active_adapter not in self.adapters):
+            raise KeyError(f"adapter {active_adapter} not found")
+        return active_adapter
+
     def _serialize_grads(
         self,
         grads: Sequence[torch.Tensor],

+ 1 - 0
src/petals/server/server.py

@@ -534,6 +534,7 @@ class ModuleContainer(threading.Thread):
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
+                adapters=adapters,
                 dht_prefix=dht_prefix,
                 push_manager=self.push_manager,
                 session_queues=session_queues,

+ 48 - 7
src/petals/utils/peft.py

@@ -1,3 +1,4 @@
+import contextlib
 import re
 import time
 from typing import Optional, Sequence
@@ -118,6 +119,47 @@ def load_peft(
             time.sleep(delay)
 
 
+class AdapterContextMixin:
+    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
+
+    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
+    _context_active_adapter = ADAPTER_NOT_SET
+
+    @staticmethod
+    @contextlib.contextmanager
+    def using_adapter(active_adapter: Optional[str]):
+        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
+        try:
+            yield
+        finally:
+            AdapterContextMixin._context_active_adapter = prev
+
+    @property
+    def active_adapter(self):
+        if self._context_active_adapter == self.ADAPTER_NOT_SET:
+            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
+        return self._context_active_adapter
+
+    @active_adapter.setter
+    def active_adapter(self, value: Optional[str]):
+        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter"
+
+
+using_adapter = AdapterContextMixin.using_adapter
+
+
+class LoraLinear(lora.Linear, AdapterContextMixin):
+    """LoRA linear layer that uses adapter selected via using_adapter"""
+
+
+class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
+    """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
+
+
+class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
+    """LoRA linear 4-bit that uses adapter selected via using_adapter"""
+
+
 def create_lora_adapter(block, quant_type: QuantType):
     for _, module in block.named_modules():
         for child_name, child in module.named_children():
@@ -130,8 +172,8 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "threshold": 6.0,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = lora.Linear8bitLt(
-                    child_name,
+                lora_wrapped_child = LoraLinear8bitLt(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
@@ -143,22 +185,21 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "blocksize": 64,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = lora.Linear4bit(
-                    child_name,
+                lora_wrapped_child = LoraLinear4bit(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
                 )
             else:
                 bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = lora.Linear(
-                    child_name,
+                lora_wrapped_child = LoraLinear(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     bias=bias,
                 )
             if lora_wrapped_child:
-                lora_wrapped_child.active_adapter = None
                 lora_wrapped_child.weight = child.weight
                 lora_wrapped_child.bias = child.bias
                 for p in lora_wrapped_child.parameters():