Quellcode durchsuchen

select adapter by name in handler.py

artek0chumak vor 2 Jahren
Ursprung
Commit
1e227240e5
3 geänderte Dateien mit 38 neuen und 9 gelöschten Zeilen
  1. 1 0
      src/petals/data_structures.py
  2. 22 0
      src/petals/server/backend.py
  3. 15 9
      src/petals/server/handler.py

+ 1 - 0
src/petals/data_structures.py

@@ -57,3 +57,4 @@ class InferenceMetadata:
     uid: ExpertUID
     prefix_length: int
     cache_handles: Tuple[Handle, ...]
+    active_adapter: Optional[str]

+ 22 - 0
src/petals/server/backend.py

@@ -4,6 +4,7 @@ from collections import Counter
 from itertools import chain
 from typing import Any, Dict, Optional, Sequence, Tuple
 
+import peft
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.expert_uid import ExpertUID
@@ -80,6 +81,14 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
+    def forward(self, active_adapter: str, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        self.load_adapter_(active_adapter)  # empty string means remove any adapters
+        return super().forward(*inputs)
+
+    def backward(self, active_adapter: str, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        self.load_adapter_(active_adapter)  # empty string means remove any adapters
+        return super().backward(*inputs)
+
     @torch.inference_mode()
     def inference_step(
         self,
@@ -88,6 +97,9 @@ class TransformerBackend(ModuleBackend):
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+
+        if not self.load_adapter_(inference_info.active_adapter):
+            raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
@@ -139,6 +151,16 @@ class TransformerBackend(ModuleBackend):
         for p in self.module.parameters():
             p.data = dummy
 
+    def load_adapter_(self, active_adapter: str = '') -> bool:
+        """Try to make a given adapter set active if it was loaded. Return True if loaded, False if no such adapter"""
+        adapter_is_loaded = False
+        for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
+            if isinstance(layer, peft.tuners.lora.Linear):
+                layer.active_adapter = active_adapter  # empty string for no adapter
+                if active_adapter in layer.lora_A.keys():
+                    adapter_is_loaded = True
+        return adapter_is_loaded
+
 
 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""

+ 15 - 9
src/petals/server/handler.py

@@ -141,6 +141,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
+                active_adapter = metadata.get("active_adapter", "")
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
 
@@ -201,7 +202,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         )
 
                         inference_infos = tuple(
-                            InferenceMetadata(uid, prefix_length, tuple(handles))
+                            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
                             for uid, handles in zip(requested_uids, cache_handles)
                         )
 
@@ -354,13 +355,14 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
 
             hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
             )
             return runtime_pb2.ExpertResponse(
                 tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@@ -376,13 +378,14 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_forward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
             hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
             )
 
             # Split the serialized_output for streaming and respond to client
@@ -422,13 +425,14 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_backward should have number of points as number or None, got {points}"
 
             grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
             )
 
             return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@@ -442,13 +446,14 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_backward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
             grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
             )
             # Split the serialized_grad_inputs for streaming and respond
             for tensor in self._serialize_grads(grads, requested_backends, metadata):
@@ -553,6 +558,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 async def _rpc_forward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = '',
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> torch.Tensor:
@@ -584,8 +590,7 @@ async def _rpc_forward(
             hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states,
-            priority=priority,
+            active_adapter, hidden_states, priority=priority,
         )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
@@ -598,6 +603,7 @@ async def _rpc_forward(
 async def _rpc_backward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = '',
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
@@ -623,7 +629,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
         )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
+        (inputs,) = await backend.forward_pool.submit_task(active_adapter, inputs, priority=priority)
 
         assert isinstance(inputs, torch.Tensor)
 
@@ -639,7 +645,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+        (grad_outputs,) = await backend.backward_pool.submit_task(active_adapter, inp, grad_outputs, priority=priority)
 
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):