justheuristic 3 years ago
parent commit e364a1e2bf

+ 7 - 8
src/client/sequential_autograd.py

@@ -64,7 +64,7 @@ async def run_expert_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    inputs: List[torch.Tensor],
+    inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
 ) -> Sequence[torch.Tensor]:
@@ -79,10 +79,10 @@ async def run_expert_backward(
 
     # Modify forward_schema to support prompts
     args_schema, kwargs_schema = rpc_info["forward_schema"]
-    assert len(args_schema) == 1 and len(inputs) == 1
-    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)  # TODO generalize this
-
-    backward_schema = tuple(nested_flatten((forward_schema_with_prompts, rpc_info["outputs_schema"])))
+    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    # TODO generalize this
+    prompts_schema = next(iter(args_schema))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
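
Note on the new schema: instead of replicating `args_schema` per input tensor, the backward schema is now the flat concatenation of the unchanged `forward_schema`, the `outputs_schema`, and a single `prompts_schema` leaf. A toy sketch of how `nested_flatten` (mirroring hivemind's helper of the same name) produces that flat tuple; schema leaves are shown as plain strings here, whereas the real code uses tensor descriptors:

from typing import Any, Iterator

def nested_flatten(t: Any) -> Iterator[Any]:
    # Depth-first traversal that yields the leaves of nested tuples/lists/dicts.
    if isinstance(t, (list, tuple)):
        for item in t:
            yield from nested_flatten(item)
    elif isinstance(t, dict):
        for _, value in sorted(t.items()):
            yield from nested_flatten(value)
    else:
        yield t

args_schema = ("hidden_states",)    # exactly one positional arg, per the assert above
forward_schema = (args_schema, {})  # (args, kwargs), with no kwargs here
outputs_schema = ("outputs",)
prompts_schema = next(iter(args_schema))

backward_schema = tuple(nested_flatten((forward_schema, outputs_schema, prompts_schema)))
print(backward_schema)  # ('hidden_states', 'outputs', 'hidden_states')
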
@@ -152,9 +152,9 @@ async def sequential_forward(
 
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
-    intermediate_inputs: Sequence[torch.Tensor],
+    intermediate_inputs: List[torch.Tensor],
     prompts: Sequence[torch.Tensor],
-    forward_sequences: Sequence[RemoteSpanInfo],
+    forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
     """
@@ -191,7 +191,6 @@ async def sequential_backward(
                 forward_sequences.extend(backup_forward_sequences)
                 intermediate_inputs.extend(backup_intermediate_inputs)
 
-    dummy_grad_prompts = [is_dummy(grad_prompt) for grad_prompt in grad_prompts]
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
     grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else None
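
A sketch of what the final concatenation produces, with hypothetical shapes (two spans covering 2 and 3 layers; each grad prompt shaped [num_layers_in_span, batch_size, pre_seq_len, hidden_size]):

import torch

grad_prompts = [torch.zeros(2, 4, 8, 16), torch.zeros(3, 4, 8, 16)]

# Concatenating along dim=0 (the layer dimension) rebuilds one tensor
# covering all 5 layers; an empty list yields None instead.
merged = torch.cat(grad_prompts, dim=0) if grad_prompts else None
assert merged is not None and merged.shape == (5, 4, 8, 16)
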

+ 4 - 3
src/server/handler.py

@@ -225,10 +225,11 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
-    if not prompts or len(prompts) == 1 and is_dummy(prompts[0]):
+    if not prompts or is_dummy(prompts[0]):
         prompts = [DUMMY] * len(requested_backends)
         pre_seq_len = 0
     else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
         prompts = [p.squeeze(0) for p in prompts[0].split(1)]
         pre_seq_len = prompts[0].shape[-2]
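
For context, a minimal sketch of the dummy-prompt convention this branch relies on (DUMMY/is_dummy mirror the project's utils; the stacked-tensor shape [num_backends, batch_size, pre_seq_len, hidden_size] is an assumption):

import torch

DUMMY = torch.empty(0)  # empty tensor standing in for "no prompt"

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

stacked = torch.randn(3, 2, 4, 8)  # one prompt slice per requested backend
prompts = [p.squeeze(0) for p in stacked.split(1)]
assert len(prompts) == 3 and prompts[0].shape == (2, 4, 8)
pre_seq_len = prompts[0].shape[-2]  # 4
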
 
@@ -253,12 +254,12 @@ async def _rpc_backward(
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-    prompts = prompts.to(requested_backends[0].dtype) if prompts else DUMMY
 
-    if is_dummy(prompts):
+    if not prompts or is_dummy(prompts[0]):
         prompts = [DUMMY] * len(requested_backends)
         pre_seq_len = 0
     else:
+        prompts = [prompts[0].to(requested_backends[0].dtype)]
         prompts = [p.squeeze(0) for p in prompts[0].split(1)]
         pre_seq_len = prompts[0].shape[-2]
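
After this change, _rpc_forward and _rpc_backward share the same prompt-handling branch. A hedged sketch of that shared logic as a standalone helper (expand_prompts is a hypothetical name for illustration, not part of the codebase; DUMMY/is_dummy as above):

import torch
from typing import List, Tuple

DUMMY = torch.empty(0)

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

def expand_prompts(
    prompts: List[torch.Tensor], num_backends: int, dtype: torch.dtype
) -> Tuple[List[torch.Tensor], int]:
    # Missing or dummy prompts: one DUMMY per backend, zero prefix tokens.
    if not prompts or is_dummy(prompts[0]):
        return [DUMMY] * num_backends, 0
    # Otherwise cast to the backend dtype, then split one prompt per backend.
    stacked = prompts[0].to(dtype)
    per_backend = [p.squeeze(0) for p in stacked.split(1)]
    return per_backend, per_backend[0].shape[-2]
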