|
@@ -59,11 +59,11 @@ async def sequential_forward(
|
|
|
logger.debug(f"Found path from block {block_idx} via {len(sequences)} servers")
|
|
|
|
|
|
span = sequences.popleft()
|
|
|
- span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
|
|
|
|
|
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
|
|
|
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
|
|
|
|
|
|
+ span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
|
|
(outputs,) = await run_remote_forward(
|
|
|
span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
|
|
|
)
|
|
@@ -105,11 +105,27 @@ async def sequential_backward(
|
|
|
|
|
|
grad_prompts_reversed = []
|
|
|
while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
|
|
|
+ inputs = intermediate_inputs.pop()
|
|
|
+ span = forward_sequences.pop()
|
|
|
for attempt_no in itertools.count():
|
|
|
- inputs = intermediate_inputs.pop(-1)
|
|
|
- span = forward_sequences.pop(-1)
|
|
|
- span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
|
|
+ logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
|
|
|
try:
|
|
|
+ if attempt_no >= 1:
|
|
|
+ sequence_manager.update_()
|
|
|
+ _, backup_inputs, backup_sequences = await sequential_forward(
|
|
|
+ inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
|
|
|
+ )
|
|
|
+ assert len(backup_inputs) == len(backup_sequences)
|
|
|
+ assert backup_sequences[0].start == span.start
|
|
|
+ assert backup_sequences[-1].end == span.end
|
|
|
+
|
|
|
+ intermediate_inputs.extend(backup_inputs)
|
|
|
+ forward_sequences.extend(backup_sequences)
|
|
|
+ inputs = intermediate_inputs.pop()
|
|
|
+ span = forward_sequences.pop()
|
|
|
+ break
|
|
|
+
|
|
|
+ span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
|
|
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
|
|
|
grad_outputs, *span_grad_prompts = await run_remote_backward(
|
|
|
span_uids,
|
|
@@ -124,18 +140,13 @@ async def sequential_backward(
|
|
|
grad_prompts_reversed.extend(span_grad_prompts)
|
|
|
break
|
|
|
except Exception as e:
|
|
|
- logger.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
|
|
|
- await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
|
|
|
-
|
|
|
- _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
|
|
|
- inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
|
|
|
+ delay = sequence_manager.min_backoff * 2**attempt_no
|
|
|
+ logger.warning(
|
|
|
+ f"Caught exception when running backward between blocks {span.start}-{span.end} "
|
|
|
+ f"(retry in {delay:.0f} sec): {repr(e)}"
|
|
|
)
|
|
|
- assert len(intermediate_inputs) == len(forward_sequences)
|
|
|
- assert backup_forward_sequences[0].start == span.start
|
|
|
- assert backup_forward_sequences[-1].end == span.end
|
|
|
-
|
|
|
- forward_sequences.extend(backup_forward_sequences)
|
|
|
- intermediate_inputs.extend(backup_intermediate_inputs)
|
|
|
+ logger.debug("See detailed traceback below:", exc_info=True)
|
|
|
+ await asyncio.sleep(delay)
|
|
|
|
|
|
# For now, we do not support mixed dummy and grad prompts
|
|
|
# Concat in num_layer dimension
|