
fix bug: send adapter name last instead of first
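(The pooled forward/backward path batches requests by their positional tensor arguments, so a string adapter name in the leading position presumably broke batching; sending it last keeps every leading argument a tensor. Illustrative sketches follow each diff below.)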

artek0chumak committed 2 years ago · parent commit 784168e8f5
2 changed files with 8 additions and 6 deletions:
  src/petals/server/backend.py  +5 −3
  src/petals/server/handler.py  +3 −3

src/petals/server/backend.py (+5 −3)

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from collections import Counter
 from itertools import chain
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import peft
 import torch
@@ -81,12 +81,14 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
-    def forward(self, active_adapter: str, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
+        *inputs, active_adapter = inputs
         if active_adapter and not self.load_adapter_(active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().forward(*inputs)
 
-    def backward(self, active_adapter: str, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+    def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
+        *inputs, active_adapter = inputs
         if active_adapter and not self.load_adapter_(active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().backward(*inputs)
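
A minimal sketch (not from the repo; names are illustrative) of the unpacking idiom the new signatures rely on: Python's starred assignment peels the trailing adapter name off the mixed varargs tuple, leaving only tensors.

import torch

def forward(*inputs):
    # Mirror TransformerBackend.forward: the last positional argument is the
    # adapter name (or None); everything before it is a tensor.
    *tensors, active_adapter = inputs
    assert all(isinstance(t, torch.Tensor) for t in tensors)
    return tensors, active_adapter

hidden_states = torch.zeros(1, 4, 8)
tensors, adapter = forward(hidden_states, "my-lora")
print(adapter)  # "my-lora"; tensors == [hidden_states]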

src/petals/server/handler.py (+3 −3)

@@ -590,7 +590,7 @@ async def _rpc_forward(
             hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
-            active_adapter, hidden_states, priority=priority,
+            hidden_states, active_adapter, priority=priority,
         )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
@@ -629,7 +629,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
         )
-        (inputs,) = await backend.forward_pool.submit_task(active_adapter, inputs, priority=priority)
+        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
 
         assert isinstance(inputs, torch.Tensor)
 
@@ -645,7 +645,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(active_adapter, inp, grad_outputs, priority=priority)
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
 
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
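
For contrast, a schematic of the fixed call order at the submit_task sites. EchoPool below is a hypothetical stand-in used only to show the argument order; the real pool is a prioritized task pool on the server side.

import asyncio
import torch

class EchoPool:
    # Hypothetical stub: echoes its positional args back to the caller.
    async def submit_task(self, *args, priority: float = 0.0):
        return args

async def main():
    pool = EchoPool()
    hidden_states = torch.zeros(1, 4, 8)
    # Fixed order: tensors first, adapter name last, matching the backend's
    # `*inputs, active_adapter = inputs` unpacking.
    args = await pool.submit_task(hidden_states, "my-lora", priority=1.0)
    *tensors, adapter = args
    assert adapter == "my-lora" and isinstance(tensors[0], torch.Tensor)

asyncio.run(main())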