Kaynağa Gözat

grad prompts reversed

justheuristic 3 yıl önce
ebeveyn
işleme
7d22978a9e
1 değiştirilmiş dosya ile 11 ekleme ve 6 silme
  1. 11 6
      src/server/handler.py

+ 11 - 6
src/server/handler.py

@@ -265,22 +265,27 @@ async def _rpc_backward(
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
-    inter_inputs = [inputs]
+    inter_inputs = []
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
             inputs[:, :pre_seq_len] += prompt
+        inter_inputs.append(inputs)
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
-        inter_inputs.append(inputs)
 
-    grad_prompts = []
+    if not is_dummy(prompts[-1]):
+        inputs[:, :pre_seq_len] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
     # Run a chain of requested backends
-    for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
         (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
-            grad_prompts.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
 
-    grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else DUMMY
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape