
make new api backwards compatible

justheuristic 3 years ago
parent
commit
1ce22a950a
3 files changed with 33 additions and 14 deletions
  1. src/client/sequential_autograd.py (+5, -6)
  2. src/peft_utils.py (+18, -0)
  3. src/server/handler.py (+10, -8)

+ 5 - 6
src/client/sequential_autograd.py

@@ -66,6 +66,7 @@ async def run_expert_backward(
     rpc_info: RPCInfo,
     inputs: List[torch.Tensor],
     grad_outputs: List[torch.Tensor],
+    *extra_tensors: torch.Tensor,
 ) -> Sequence[torch.Tensor]:
     """
     Serializes grad outputs and calls "expert_backward".
@@ -74,13 +75,12 @@ async def run_expert_backward(
     """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu)))
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
 
     # Modify forward_schema to support prompts
     args_schema, kwargs_schema = rpc_info["forward_schema"]
-    # TODO: rm this assert when support arbitrary number of input tensors
-    assert len(args_schema) == 1 and len(inputs) == 2
-    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+    assert len(args_schema) == 1 and len(inputs) == 1
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)  # TODO unfuck this
 
     backward_schema = tuple(nested_flatten((forward_schema_with_prompts, rpc_info["outputs_schema"])))
 
@@ -173,9 +173,8 @@ async def sequential_backward(
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
 
-                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
                 grad_outputs, span_grad_prompts = await run_expert_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs_and_prompts, grad_outputs
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts.append(span_grad_prompts)
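
A minimal, self-contained sketch of the calling convention this hunk introduces (the flatten_request helper and the tensor shapes are illustrative assumptions, not part of the codebase): prompts are appended to the serialized request only when the caller passes them, so pre-existing call sites keep working unchanged.

    import torch

    def flatten_request(inputs, grad_outputs, *extra_tensors):
        # mirrors run_expert_backward: move grad outputs to CPU, then flatten
        # everything into one tuple of tensors for serialization
        grad_outputs_cpu = tuple(t.cpu() for t in grad_outputs)
        return (*inputs, *grad_outputs_cpu, *extra_tensors)

    hidden = torch.randn(1, 4, 8)    # hypothetical (batch, seq, hidden) activations
    grads = torch.randn(1, 4, 8)     # gradients w.r.t. the block outputs
    prompts = torch.randn(1, 2, 8)   # hypothetical deep-prompt tensor

    old_call = flatten_request([hidden], [grads])            # pre-commit call sites: 2 tensors
    new_call = flatten_request([hidden], [grads], prompts)   # prompt-aware call sites: 3 tensors
    print(len(old_call), len(new_call))  # 2 3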

+ 18 - 0
src/peft_utils.py

@@ -0,0 +1,18 @@
+"""
+
+Generalized parameter-efficient finetuning modules that support deep prompts and several types of adapters.
+Designed to be used on both client and server side.
+
+"""
+import torch.nn as nn
+
+from src.utils.misc import DUMMY
+
+
+class GenericPEFTModule(nn.Module):
+    """Container for PEFT parameters for a single transformer block, supports multiple modes"""
+
+    def __init__(self, hidden_size: int):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.prompts = nn.Parameter(DUMMY)
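
A hedged sketch of the placeholder pattern this module relies on. DUMMY and is_dummy are re-implemented here under the assumption that DUMMY is a zero-element tensor (as suggested by src/utils/misc); enable_prompts is a hypothetical helper added only for illustration, showing how real prompt parameters could replace the placeholder on demand.

    import torch
    import torch.nn as nn

    DUMMY = torch.empty(0)  # assumption: mirrors src.utils.misc.DUMMY

    def is_dummy(tensor: torch.Tensor) -> bool:
        # assumption: a dummy placeholder is recognized by having no elements
        return tensor.numel() == 0

    class GenericPEFTModule(nn.Module):
        """Container for PEFT parameters for a single transformer block."""

        def __init__(self, hidden_size: int):
            super().__init__()
            self.hidden_size = hidden_size
            self.prompts = nn.Parameter(DUMMY)  # no prompt parameters allocated yet

        def enable_prompts(self, num_prefix_tokens: int):
            # hypothetical helper: allocate trainable deep-prompt parameters
            self.prompts = nn.Parameter(torch.zeros(num_prefix_tokens, self.hidden_size))

    block_peft = GenericPEFTModule(hidden_size=8)
    print(is_dummy(block_peft.prompts))   # True: prompt tuning is off by default
    block_peft.enable_prompts(num_prefix_tokens=2)
    print(block_peft.prompts.shape)       # torch.Size([2, 8])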

+ 10 - 8
src/server/handler.py

@@ -137,11 +137,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
+
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
-        )
+        )  # TODO unfuck
 
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
@@ -162,11 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
-        )
+        )  # TODO unfuck
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
@@ -245,11 +246,13 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     return hidden_states
 
 
-async def _rpc_backward(inputs: torch.Tensor, prompts: torch.Tensor, grad_outputs: torch.Tensor, requested_backends):
+async def _rpc_backward(*flat_tensors: torch.Tensor, requested_backends):
+    inputs, grad_outputs, *prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
-    prompts = prompts.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+    NO_PROMPTS = not prompts
+    prompts = prompts[0].to(requested_backends[0].dtype) if prompts else DUMMY  # the extra tensor, if any, is the prompts
 
     if is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
@@ -283,5 +286,4 @@ async def _rpc_backward(inputs: torch.Tensor, prompts: torch.Tensor, grad_output
 
     is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
     grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
-    grads = [grad_outputs, grad_prompts]
-    return grads
+    return [grad_outputs] if NO_PROMPTS else [grad_outputs, grad_prompts]  # TODO un-duct-tape
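
A minimal sketch of the request unpacking that makes the backward handlers tolerate both old and new clients (unpack_backward_request is illustrative only: it takes the first extra tensor as the prompts and skips dtype casting and the actual backward pass). Old clients send two tensors and get back only grad_outputs; prompt-aware clients send three and also get grad_prompts.

    import torch

    DUMMY = torch.empty(0)  # assumption: zero-element placeholder, as in src.utils.misc

    def unpack_backward_request(*flat_tensors: torch.Tensor):
        # old clients send (inputs, grad_outputs); new clients append a prompts tensor
        inputs, grad_outputs, *maybe_prompts = flat_tensors
        no_prompts = not maybe_prompts
        prompts = maybe_prompts[0] if maybe_prompts else DUMMY
        return inputs, grad_outputs, prompts, no_prompts

    # old-style request: the response will carry only grad_outputs
    *_, no_prompts_old = unpack_backward_request(torch.randn(1, 4, 8), torch.randn(1, 4, 8))
    # new-style request: the response will also carry grad_prompts
    *_, no_prompts_new = unpack_backward_request(
        torch.randn(1, 4, 8), torch.randn(1, 4, 8), torch.randn(1, 2, 8)
    )
    print(no_prompts_old, no_prompts_new)  # True False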