|
@@ -134,10 +134,9 @@ async def run_rpc_backward(
|
|
|
priority = prioritizer.prioritize(
|
|
|
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
|
|
|
)
|
|
|
- (grad_outputs,), grad_kwargs = await backend.backward_pool.submit_task(
|
|
|
+ (*grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
|
|
|
active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens
|
|
|
)
|
|
|
-
|
|
|
assert isinstance(grad_outputs, torch.Tensor)
|
|
|
if not is_dummy(prompt):
|
|
|
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
|