@@ -265,22 +265,27 @@ async def _rpc_backward(
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
-    inter_inputs = [inputs]
+    inter_inputs = []
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
             inputs[:, :pre_seq_len] += prompt
+        inter_inputs.append(inputs)
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
-        inter_inputs.append(inputs)
 
-    grad_prompts = []
+    if not is_dummy(prompts[-1]):
+        inputs[:, :pre_seq_len] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
     # Run a chain of requested backends
-    for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
         (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
-            grad_prompts.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
 
-    grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else DUMMY
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
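
For intuition, here is a minimal standalone sketch of the ordering this patch enforces: each stage's input is recorded before its forward pass, so the reverse pass hands every backend the exact tensor it originally consumed, last stage first. This is not the real API of the patched code; ToyBackend and run_chain_backward are made-up names, and the submit_task pools and prompt handling are omitted.

    # Toy illustration only (assumed names, not the patched module's API).
    import torch

    class ToyBackend:
        """Stand-in for one pipeline stage; wraps a single linear layer."""
        def __init__(self, dim: int):
            self.weight = torch.randn(dim, dim, requires_grad=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x @ self.weight

        def backward(self, inp: torch.Tensor, grad_out: torch.Tensor) -> torch.Tensor:
            # Re-run forward on the saved *input* of this stage and push the
            # incoming gradient through it; returns the gradient w.r.t. the input.
            inp = inp.detach().requires_grad_(True)
            out = self.forward(inp)
            out.backward(grad_out)
            return inp.grad

    def run_chain_backward(backends, inputs, grad_outputs):
        inter_inputs = []
        for backend in backends[:-1]:
            inter_inputs.append(inputs)       # save the input *before* forwarding it
            inputs = backend.forward(inputs)  # then advance the chain
        inter_inputs.append(inputs)           # input of the last stage; its output is not needed

        # Walk the chain in reverse, pairing each stage with the input it saw.
        for inp, backend in zip(reversed(inter_inputs), reversed(backends)):
            grad_outputs = backend.backward(inp, grad_outputs)
        return grad_outputs

    backends = [ToyBackend(4) for _ in range(3)]
    x = torch.randn(2, 4)
    grad_x = run_chain_backward(backends, x, torch.ones(2, 4))
    print(grad_x.shape)  # torch.Size([2, 4])

Appending the input after submit_task (as the old code did) would pair each stage's backward with the *next* stage's input, which is the off-by-one the patch removes; collecting gradients into grad_prompts_reversed and flipping them once at the end keeps the returned prompt gradients in forward order.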