|
@@ -216,7 +216,7 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
expert_logits, alive_ix, alive_expert_probas = ctx.saved_tensors
|
|
expert_logits, alive_ix, alive_expert_probas = ctx.saved_tensors
|
|
alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
|
|
alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
|
|
|
|
|
|
- jobs = [partial(cls._run_expert_backward, ctx, prob, grad_outputs_flat)
|
|
|
|
|
|
+ jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
|
|
for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
|
|
for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
|
|
results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
|
|
results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
|
|
survived_backward, survived_grad_inputs = zip(*((alive_ix[i], grads) for i, grads in enumerate(results)))
|
|
survived_backward, survived_grad_inputs = zip(*((alive_ix[i], grads) for i, grads in enumerate(results)))
|