Selaa lähdekoodia

actually call backward in parallel

justheuristic 5 vuotta sitten
vanhempi
commit
9a2d661561
1 muutettua tiedostoa jossa 1 lisäystä ja 1 poistoa
  1. 1 1
      tesseract/client/moe.py

+ 1 - 1
tesseract/client/moe.py

@@ -228,7 +228,7 @@ class _RemoteMoECall(torch.autograd.Function):
 
         jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
-        results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=None, timeout_total=backward_timeout)
+        results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=backward_timeout, timeout_total=None)
         backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
         backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
         backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]