@@ -64,7 +64,7 @@ async def run_expert_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    inputs: List[torch.Tensor],
+    inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
 ) -> Sequence[torch.Tensor]:
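This hunk narrows `inputs` from a one-element list to a single tensor, so callers now pass the tensor itself rather than wrapping it. A minimal sketch of the call-side effect, using a hypothetical stand-in for the patched function (names and shapes are illustrative, not the real API):

```python
from typing import List, Sequence
import torch

def run_expert_backward_stub(inputs: torch.Tensor, grad_outputs: List[torch.Tensor]) -> Sequence[torch.Tensor]:
    # Mirrors the new assert in the next hunk: inputs is one tensor, not a list.
    assert isinstance(inputs, torch.Tensor)
    return [torch.zeros_like(inputs)]

x = torch.randn(1, 4, 8)
grads = run_expert_backward_stub(x, [torch.ones_like(x)])  # was: run_expert_backward([x], ...)
```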
@@ -79,10 +79,10 @@ async def run_expert_backward(
 
     # Modify forward_schema to support prompts
     args_schema, kwargs_schema = rpc_info["forward_schema"]
-    assert len(args_schema) == 1 and len(inputs) == 1
-    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)  # TODO generalize this
-
-    backward_schema = tuple(nested_flatten((forward_schema_with_prompts, rpc_info["outputs_schema"])))
+    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    # TODO generalize this
+    prompts_schema = next(iter(args_schema))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
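The backward schema is now built by flattening `(forward_schema, outputs_schema, prompts_schema)` in order, with the prompts reusing the lone input descriptor. A toy sketch of the resulting flattening order, where `nested_flatten` is a simplified stand-in for hivemind's helper and the `*_descr` strings are placeholder schema values:

```python
from typing import Any, Iterator

def nested_flatten(t: Any) -> Iterator[Any]:
    """Simplified stand-in: yield the leaves of nested tuples/lists/dicts depth-first."""
    if isinstance(t, (tuple, list)):
        for item in t:
            yield from nested_flatten(item)
    elif isinstance(t, dict):
        for key in sorted(t):
            yield from nested_flatten(t[key])
    else:
        yield t

forward_schema = (("input_descr",), {})   # (args_schema, kwargs_schema)
outputs_schema = ("output_descr",)
args_schema, kwargs_schema = forward_schema
assert len(args_schema) == 1
prompts_schema = next(iter(args_schema))  # prompts reuse the single input descriptor

backward_schema = tuple(nested_flatten((forward_schema, outputs_schema, prompts_schema)))
print(backward_schema)  # ('input_descr', 'output_descr', 'input_descr')
```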
@@ -152,9 +152,9 @@ async def sequential_forward(
 
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
-    intermediate_inputs: Sequence[torch.Tensor],
+    intermediate_inputs: List[torch.Tensor],
     prompts: Sequence[torch.Tensor],
-    forward_sequences: Sequence[RemoteSpanInfo],
+    forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
     """
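`intermediate_inputs` and `forward_sequences` are re-annotated as `List` because the function mutates them in place (see the `.extend(...)` calls in the next hunk), and `typing.Sequence` exposes no mutating methods. A minimal illustration of the contract difference, with made-up span strings:

```python
from typing import List

def restore(forward_sequences: List[str], backup: List[str]) -> None:
    # In-place mutation like this is only part of the contract with List;
    # typing.Sequence has no .extend/.pop, so type checkers would reject it.
    forward_sequences.extend(backup)

spans = ["span[0:4]"]
restore(spans, ["span[4:8]"])
print(spans)  # ['span[0:4]', 'span[4:8]']
```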
@@ -191,7 +191,6 @@ async def sequential_backward(
         forward_sequences.extend(backup_forward_sequences)
         intermediate_inputs.extend(backup_intermediate_inputs)
 
-    dummy_grad_prompts = [is_dummy(grad_prompt) for grad_prompt in grad_prompts]
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
     grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else None
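The deleted `dummy_grad_prompts` list was computed but never used afterwards, so the hunk drops it. The surviving concat merges the per-span prompt gradients along the layer dimension; a toy example with made-up shapes of the form [layers_in_span, batch, prompt_len, hidden]:

```python
import torch

# Two spans covering 2 and 3 layers respectively; shapes are illustrative only.
grad_prompts = [torch.zeros(2, 1, 4, 8), torch.zeros(3, 1, 4, 8)]
merged = torch.cat(grad_prompts, dim=0) if grad_prompts else None
print(merged.shape)  # torch.Size([5, 1, 4, 8])
```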