@@ -94,4 +94,4 @@ class _ParallelApplyFunction(torch.autograd.Function):
         futures = [run_in_background(run_isolated_backward, func, context, *grads)
                    for context, grads in zip(contexts, grad_outputs_per_call)]
         flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())
-        return None, None, None, None, *flat_grads_wrt_input
+        return (None, None, None, None, *flat_grads_wrt_input)
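
Note on the hunk above: the only change is wrapping the returned tuple in parentheses. A plausible motivation (an assumption, not stated in the diff itself) is compatibility with Python versions older than 3.8, where a starred expression in an unparenthesized return statement is a SyntaxError; the parenthesized tuple display has been valid since PEP 448 (Python 3.5). A minimal sketch of the difference, using a hypothetical function name:

    def backward_like(*grads):
        # return None, *grads       # SyntaxError before Python 3.8
        return (None, *grads)       # parenthesized tuple display, valid on Python 3.5+

    print(backward_like(1.0, 2.0))  # (None, 1.0, 2.0)

On Python 3.8 and newer both spellings produce the same tuple at runtime, so the parenthesized form is simply the more portable one.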