|
@@ -272,13 +272,11 @@ async def _rpc_backward(
|
|
|
inputs[:, :pre_seq_len] += prompt
|
|
|
(inputs,) = await backend.forward_pool.submit_task(inputs)
|
|
|
assert isinstance(inputs, torch.Tensor)
|
|
|
- inter_inputs.append(inputs.clone()) # TODO optimize: reduce the number of copies
|
|
|
+ inter_inputs.append(inputs)
|
|
|
|
|
|
grad_prompts = []
|
|
|
# Run a chain of requested backends
|
|
|
for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
|
|
|
- if not is_dummy(prompt):
|
|
|
- inp[:, :pre_seq_len] += prompt
|
|
|
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
|
|
|
assert isinstance(grad_outputs, torch.Tensor)
|
|
|
if not is_dummy(prompt):
|