Forráskód Böngészése

wip: parallel fault-tolerant moe backward pass

justheuristic 5 éve
szülő
commit
80ab75583f
1 módosított fájl, 1 hozzáadás és 1 törlés
  1. tesseract/client/moe.py (+1 −1)

+ 1 - 1
tesseract/client/moe.py

@@ -217,7 +217,7 @@ class _RemoteMoECall(torch.autograd.Function):
         jobs = [partial(cls._run_expert_backward, ctx, prob, grad_outputs_flat)
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
         results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
-        survived_backward, survived_grad_inputs = zip(*(alive_ix[i], grads for i, grads in enumerate(results)))
+        survived_backward, survived_grad_inputs = zip(*((alive_ix[i], grads) for i, grads in enumerate(results)))
 
         survived_ix = alive_ix[survived_backward]
         survived_expert_probas = torch.softmax(expert_logits[survived_ix], dim=0)