
actually call backward in parallel

justheuristic 5 years ago
commit 1606fca863
1 changed file with 1 addition and 1 deletion

+1 -1  tesseract/utils/autograd.py

@@ -91,7 +91,7 @@ class _ParallelApplyFunction(torch.autograd.Function):
     def backward(ctx, *grad_outputs_flat: torch.Tensor):
         func, contexts, output_strides = ctx._inner_func, ctx._call_contexts, ctx._output_strides
         grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]] for i in range(len(contexts))]
-        futures = [run_in_background(run_isolated_backward(func, context, *grads))
+        futures = [run_in_background(run_isolated_backward, func, context, *grads)
                    for context, grads in zip(contexts, grad_outputs_per_call)]
         flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())
         return None, None, None, None, *flat_grads_wrt_input
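
The change matters because the old line evaluated run_isolated_backward(func, context, *grads) directly on the calling thread and only then handed its result to run_in_background, so the per-call backward passes ran one after another. Passing the callable and its arguments separately lets the helper schedule the call itself. Below is a minimal sketch of the pattern, assuming run_in_background is a thin wrapper around concurrent.futures.ThreadPoolExecutor.submit; the actual tesseract helper may differ.

import time
from concurrent.futures import Future, ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=4)

def run_in_background(func, *args, **kwargs) -> Future:
    """Schedule func(*args, **kwargs) on a worker thread and return a Future."""
    return _executor.submit(func, *args, **kwargs)

def fake_backward(i: int):
    time.sleep(1)          # stand-in for one isolated backward pass
    return (i, i * 2)      # stand-in for gradients w.r.t. that call's inputs

# Before the fix (broken): fake_backward(i) runs right here, sequentially,
# and run_in_background receives a tuple instead of a callable.
#   futures = [run_in_background(fake_backward(i)) for i in range(4)]  # ~4s, then fails

# After the fix: the callable and its arguments are submitted unevaluated,
# so all four backward passes run concurrently on worker threads.
futures = [run_in_background(fake_backward, i) for i in range(4)]
flat_grads = tuple(grad for future in futures for grad in future.result())  # ~1s total
print(flat_grads)

Thread-based scheduling helps here because PyTorch's C++ autograd engine releases the GIL while computing gradients, so the isolated backward calls can genuinely overlap.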