|
@@ -229,6 +229,8 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
|
|
|
|
|
|
if is_dummy(prompts):
|
|
|
prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
|
|
|
+ else:
|
|
|
+ prompts = [p.squeeze(0) for p in prompts.split(1)]
|
|
|
|
|
|
# Run a forward chain to collect intermediate inputs
|
|
|
# Note that we do not forward for the last module since we do not need its output
|
|
@@ -241,7 +243,7 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
|
|
|
|
|
|
grad_prompts = []
|
|
|
# Run a chain of requested backends
|
|
|
- for inp, prompt, backend in zip(inter_inputs[::-1], prompts.flip(0), requested_backends[::-1]):
|
|
|
+ for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
|
|
|
grads = await backend.backward_pool.submit_task(inp, prompt, grad_outputs)
|
|
|
assert isinstance(grads, (list, tuple)) and len(grads) == 2
|
|
|
grad_outputs, grad_prompt = grads
|