|
@@ -77,7 +77,7 @@ class _ParallelApplyFunction(torch.autograd.Function):
|
|
|
assert num_calls * num_args_per_call == len(args_flat)
|
|
|
args_per_call = [args_flat[i * num_args_per_call: (i + 1) * num_args_per_call] for i in range(num_calls)]
|
|
|
|
|
|
- futures = [run_in_background(run_isolated_backward, func, *args) for args in args_per_call]
|
|
|
+ futures = [run_in_background(run_isolated_forward, func, *args) for args in args_per_call]
|
|
|
|
|
|
outputs, contexts = zip(*[future.result() for future in futures])
|
|
|
output_strides = np.cumsum([0] + list(map(len, outputs)))
|