|
@@ -228,7 +228,7 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
|
|
|
jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
|
|
jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
|
|
for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
|
|
for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
|
|
- results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=None, timeout_total=backward_timeout)
|
|
|
|
|
|
+ results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=backward_timeout, timeout_total=None)
|
|
backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
|
|
backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
|
|
backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
|
|
backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
|
|
backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
|
|
backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
|