
Make backward more fault-tolerant

Aleksandr Borzunov, 2 years ago
commit a58a8b95d0
1 changed file with 26 additions and 15 deletions

src/client/sequential_autograd.py  +26 -15

@@ -59,11 +59,11 @@ async def sequential_forward(
                     logger.debug(f"Found path from block {block_idx} via {len(sequences)} servers")
 
                 span = sequences.popleft()
-                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
 
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 (outputs,) = await run_remote_forward(
                     span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
                 )
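
For reference, the span_uids string built above is just the chain of block UIDs covered by the span, joined by CHAIN_DELIMITER. A minimal illustration with made-up UIDs (assuming CHAIN_DELIMITER is a single space, as in hivemind-style chained RPC calls; the Span class is a hypothetical stand-in for the route spans popped from sequences):

CHAIN_DELIMITER = " "  # assumption: a single space separates chained block UIDs
block_uids = [f"bloom.{i}" for i in range(8)]  # hypothetical block UIDs

class Span:
    # minimal stand-in for one server's contiguous range of blocks
    def __init__(self, start, end, peer_id=None):
        self.start, self.end, self.peer_id = start, end, peer_id

span = Span(2, 5)
span_uids = CHAIN_DELIMITER.join(block_uids[span.start : span.end])
print(span_uids)  # -> "bloom.2 bloom.3 bloom.4"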
@@ -105,11 +105,27 @@ async def sequential_backward(
 
     grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
+        inputs = intermediate_inputs.pop()
+        span = forward_sequences.pop()
         for attempt_no in itertools.count():
-            inputs = intermediate_inputs.pop(-1)
-            span = forward_sequences.pop(-1)
-            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+            logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
             try:
+                if attempt_no >= 1:
+                    sequence_manager.update_()
+                    _, backup_inputs, backup_sequences = await sequential_forward(
+                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                    )
+                    assert len(backup_inputs) == len(backup_sequences)
+                    assert backup_sequences[0].start == span.start
+                    assert backup_sequences[-1].end == span.end
+
+                    intermediate_inputs.extend(backup_inputs)
+                    forward_sequences.extend(backup_sequences)
+                    inputs = intermediate_inputs.pop()
+                    span = forward_sequences.pop()
+                    # fall through: retry backward right away on the rebuilt span
+
+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids,
@@ -124,18 +140,13 @@ async def sequential_backward(
                 grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
-                logger.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
-
-                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
-                    inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                delay = sequence_manager.min_backoff * 2**attempt_no
+                logger.warning(
+                    f"Caught exception when running backward between blocks {span.start}-{span.end} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                assert len(intermediate_inputs) == len(forward_sequences)
-                assert backup_forward_sequences[0].start == span.start
-                assert backup_forward_sequences[-1].end == span.end
-
-                forward_sequences.extend(backup_forward_sequences)
-                intermediate_inputs.extend(backup_intermediate_inputs)
+                logger.debug("See detailed traceback below:", exc_info=True)
+                await asyncio.sleep(delay)
 
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
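
The core of this change: when backward over a server span fails, the client refreshes its view of the swarm (sequence_manager.update_()), re-runs sequential_forward over just that span to rebuild the missing activations through live servers, and retries with exponential backoff, logging the full traceback at debug level. A condensed, self-contained sketch of that control flow (replan_forward and run_backward are hypothetical stand-ins for sequential_forward and run_remote_backward; prompts and logging are omitted):

import asyncio
import itertools

async def backward_with_recovery(grad, inputs_stack, span_stack,
                                 replan_forward, run_backward, min_backoff=1.0):
    # Walk the recorded forward spans in reverse, retrying each backward call.
    while span_stack and inputs_stack:
        inputs, span = inputs_stack.pop(), span_stack.pop()
        for attempt_no in itertools.count():
            try:
                if attempt_no >= 1:
                    # Rebuild activations for the failed span via live servers;
                    # the span may come back split into several sub-spans.
                    new_inputs, new_spans = await replan_forward(inputs, span)
                    inputs_stack.extend(new_inputs)
                    span_stack.extend(new_spans)
                    inputs, span = inputs_stack.pop(), span_stack.pop()
                grad = await run_backward(span, inputs, grad)
                break  # this span succeeded; move on to the previous one
            except Exception:
                # exponential backoff: 1 s, 2 s, 4 s, ... between retries
                await asyncio.sleep(min_backoff * 2**attempt_no)
    return grad

If a retry succeeds, the rebuilt sub-spans left on the stacks are processed by the outer loop exactly like the original ones, so a single mid-chain failure only costs recomputing the forward pass for that span.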