@@ -163,21 +163,19 @@ async def sequential_backward(
     """
     assert len(intermediate_inputs) == len(forward_sequences)
 
-    grad_prompts = []
+    grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         while True:
+            inputs = intermediate_inputs.pop(-1)
+            span = forward_sequences.pop(-1)
+            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             try:
-                inputs = intermediate_inputs.pop(-1)
-                span = forward_sequences.pop(-1)
-
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-
                 grad_outputs, *span_grad_prompts = await run_expert_backward(
                     span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
                 )
                 grad_outputs = [grad_outputs]
-                grad_prompts.extend(span_grad_prompts)
+                grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
@@ -193,7 +191,7 @@ async def sequential_backward(
 
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
-    grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else None
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
     return grad_outputs, grad_prompts
 
 
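Why the reversal is needed: the backward loop pops forward spans from the end of the list, so per-span prompt gradients are appended to grad_prompts_reversed in reverse layer order; reversing the list before torch.cat restores the layer order that prompts uses. A minimal sketch of that invariant, with hypothetical toy tensors standing in for the real per-block prompt gradients:

    import torch

    # Hypothetical per-span prompt gradients; only the ordering matters here.
    # The leading dimension plays the role of the number of layers in each span.
    grad_span_first = torch.zeros(2, 1, 4)  # gradients for layers 0-1 (first span)
    grad_span_last = torch.ones(3, 1, 4)    # gradients for layers 2-4 (last span)

    # sequential_backward walks forward_sequences from the end, so the last
    # span's gradients are collected first:
    grad_prompts_reversed = [grad_span_last, grad_span_first]

    # Reversing the list before concatenating along the num_layers dimension
    # restores the original layer order:
    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0)
    assert grad_prompts.shape[0] == 5
    assert torch.equal(grad_prompts[:2], grad_span_first)  # layers 0-1 come first again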