@@ -91,7 +91,7 @@ class _ParallelApplyFunction(torch.autograd.Function):
     def backward(ctx, *grad_outputs_flat: torch.Tensor):
         func, contexts, output_strides = ctx._inner_func, ctx._call_contexts, ctx._output_strides
         grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]] for i in range(len(contexts))]
-        futures = [run_in_background(run_isolated_backward(func, context, *grads))
+        futures = [run_in_background(run_isolated_backward, func, context, *grads)
                    for context, grads in zip(contexts, grad_outputs_per_call)]
         flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())
         return None, None, None, None, *flat_grads_wrt_input
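
The fix matters because the old line evaluated `run_isolated_backward(func, context, *grads)` eagerly on the calling thread, serializing the per-call backward passes, and then handed its return value (not a callable) to `run_in_background`. Passing the function and its arguments separately lets the helper invoke it on a worker thread. Below is a minimal sketch of the difference, assuming `run_in_background` is a thin wrapper over `concurrent.futures.ThreadPoolExecutor.submit` (the real helper lives elsewhere in the repository and is not shown in this hunk; `slow_square` is a hypothetical stand-in for `run_isolated_backward`):

from concurrent.futures import Future, ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=8)  # hypothetical shared worker pool

def run_in_background(func, *args, **kwargs) -> Future:
    """Schedule func(*args, **kwargs) on a worker thread; return its Future."""
    return _executor.submit(func, *args, **kwargs)

def slow_square(x: int) -> int:
    return x * x

# Old (buggy) pattern: slow_square(3) runs eagerly on the caller's thread, and
# its result, the int 9, is submitted as if it were a callable; the worker
# then raises TypeError once future.result() is awaited.
# future = run_in_background(slow_square(3))

# New (fixed) pattern: the callable and its arguments are passed separately,
# so the call itself happens on the worker thread.
future = run_in_background(slow_square, 3)
assert future.result() == 9

With the fixed form, the list comprehension schedules every `run_isolated_backward` call before the subsequent loop blocks on any `future.result()`, which is what lets the per-call backward passes actually run concurrently.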