|
@@ -72,11 +72,14 @@ async def sequential_forward(
|
|
|
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
|
|
|
|
|
|
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
|
|
- metadata = sequence_manager.get_request_metadata(
|
|
|
- 'rpc_forward', span_uids, *inputs_and_prompts)
|
|
|
+ metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
|
|
|
(outputs,) = await run_remote_forward(
|
|
|
- span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts,
|
|
|
- timeout=sequence_manager.timeout, metadata=metadata
|
|
|
+ span_uids,
|
|
|
+ stub,
|
|
|
+ sequence_manager.rpc_info,
|
|
|
+ *inputs_and_prompts,
|
|
|
+ timeout=sequence_manager.timeout,
|
|
|
+ metadata=metadata,
|
|
|
)
|
|
|
|
|
|
assert isinstance(outputs, torch.Tensor)
|
|
@@ -150,7 +153,8 @@ async def sequential_backward(
|
|
|
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
|
|
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
|
|
|
metadata = sequence_manager.get_request_metadata(
|
|
|
- 'rpc_backward', span_uids, *inputs, *grad_outputs, peer_id=span.peer_id)
|
|
|
+ "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
|
|
|
+ )
|
|
|
grad_outputs, *span_grad_prompts = await run_remote_backward(
|
|
|
span_uids,
|
|
|
stub,
|