justheuristic il y a 5 ans
Parent
commit
05e7c92f3d
1 fichiers modifiés avec 1 ajouts et 1 suppressions
  1. 1 1
      tesseract/client/moe.py

+ 1 - 1
tesseract/client/moe.py

@@ -216,7 +216,7 @@ class _RemoteMoECall(torch.autograd.Function):
         expert_logits, alive_ix, alive_expert_probas = ctx.saved_tensors
         alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
 
-        jobs = [partial(cls._run_expert_backward, ctx, prob, grad_outputs_flat)
+        jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
         results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
         survived_backward, survived_grad_inputs = zip(*((alive_ix[i], grads) for i, grads in enumerate(results)))