|
@@ -153,7 +153,7 @@ async def sequential_forward(
|
|
async def sequential_backward(
|
|
async def sequential_backward(
|
|
grad_outputs: Sequence[torch.Tensor],
|
|
grad_outputs: Sequence[torch.Tensor],
|
|
intermediate_inputs: List[torch.Tensor],
|
|
intermediate_inputs: List[torch.Tensor],
|
|
- prompts: Sequence[torch.Tensor],
|
|
|
|
|
|
+ prompts: torch.Tensor,
|
|
forward_sequences: List[RemoteSpanInfo],
|
|
forward_sequences: List[RemoteSpanInfo],
|
|
sequence_manager: RemoteSequenceManager,
|
|
sequence_manager: RemoteSequenceManager,
|
|
) -> Sequence[torch.Tensor]:
|
|
) -> Sequence[torch.Tensor]:
|
|
@@ -174,7 +174,7 @@ async def sequential_backward(
|
|
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
|
|
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
|
|
|
|
|
|
grad_outputs, *span_grad_prompts = await run_expert_backward(
|
|
grad_outputs, *span_grad_prompts = await run_expert_backward(
|
|
- span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts
|
|
|
|
|
|
+ span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start: span.end]
|
|
)
|
|
)
|
|
grad_outputs = [grad_outputs]
|
|
grad_outputs = [grad_outputs]
|
|
grad_prompts.extend(span_grad_prompts)
|
|
grad_prompts.extend(span_grad_prompts)
|
|
@@ -182,7 +182,7 @@ async def sequential_backward(
|
|
except Exception as e:
|
|
except Exception as e:
|
|
logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
|
|
logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
|
|
_, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
|
|
_, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
|
|
- inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
|
|
|
|
|
|
+ inputs, prompts[span.start: span.end], sequence_manager, start_index=span.start, end_index=span.end
|
|
)
|
|
)
|
|
assert len(intermediate_inputs) == len(forward_sequences)
|
|
assert len(intermediate_inputs) == len(forward_sequences)
|
|
assert backup_forward_sequences[0].start == span.start
|
|
assert backup_forward_sequences[0].start == span.start
|