
fix prompts for partial chains

justheuristic 3 years ago
parent
commit 0791f854f8
1 changed file with 6 additions and 8 deletions

+6 -8
src/client/sequential_autograd.py

@@ -163,21 +163,19 @@ async def sequential_backward(
     """
     assert len(intermediate_inputs) == len(forward_sequences)
 
-    grad_prompts = []
+    grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         while True:
+            inputs = intermediate_inputs.pop(-1)
+            span = forward_sequences.pop(-1)
+            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start: span.end])
             try:
-                inputs = intermediate_inputs.pop(-1)
-                span = forward_sequences.pop(-1)
-
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-
                 grad_outputs, *span_grad_prompts = await run_expert_backward(
                     span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start: span.end]
                 )
                 grad_outputs = [grad_outputs]
-                grad_prompts.extend(span_grad_prompts)
+                grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
@@ -193,7 +191,7 @@ async def sequential_backward(
 
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
-    grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else None
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
     return grad_outputs, grad_prompts
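
Why the rename: sequential_backward walks the forward spans from last to first, so per-span prompt gradients accumulate in reverse block order, and concatenating them as-is misaligned each layer's grad prompt with its position in prompts. Below is a minimal standalone sketch of the fix, using hypothetical tensor shapes and fill values for illustration; it is not the Petals API, only the reversed-concat pattern the commit applies.

import torch

# Hypothetical dimensions for illustration only.
batch, prefix, hid = 2, 4, 8

# Backward visits spans last-to-first, so per-span prompt gradients
# arrive in reverse block order: layers 5-6 first, layers 0-1 last.
span_grads_in_arrival_order = [
    torch.zeros(2, batch, prefix, hid),        # layers 5-6
    torch.ones(3, batch, prefix, hid),         # layers 2-4
    torch.full((2, batch, prefix, hid), 2.0),  # layers 0-1
]

# Pre-fix behavior: concatenating as-is keeps the reversed span order,
# so layer 0's gradient lands at the wrong index along dim 0.
wrong = torch.cat(span_grads_in_arrival_order, dim=0)

# The commit's fix: reverse the accumulated list before concatenating,
# restoring the original layer order along the num_layer dimension.
right = torch.cat(span_grads_in_arrival_order[::-1], dim=0)

assert wrong.shape == right.shape == (7, batch, prefix, hid)
assert torch.equal(right[:2], span_grads_in_arrival_order[-1])  # layers 0-1 come first
assert not torch.equal(wrong, right)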