|
@@ -195,7 +195,7 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
# -- a list of autograd contexts, used for parallel backward
|
|
|
|
|
|
# 2. compute softmax weights for alive experts and average outputs
|
|
|
- alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
|
|
|
+ alive_expert_probs = torch.softmax(expert_logits[list(alive_ix)], dim=0)
|
|
|
|
|
|
flat_average_outputs = tuple(map(
|
|
|
lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
|
|
@@ -219,8 +219,8 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
|
|
|
survived_backward, survived_grad_inputs = zip(*((alive_ix[i], grads) for i, grads in enumerate(results)))
|
|
|
|
|
|
- survived_ix = alive_ix[survived_backward]
|
|
|
- survived_expert_probas = torch.softmax(expert_logits[survived_ix], dim=0)
|
|
|
+ survived_ix = alive_ix[list(survived_backward)]
|
|
|
+ survived_expert_probas = torch.softmax(expert_logits[list(survived_ix)], dim=0)
|
|
|
|
|
|
flat_grad_inputs = tuple(map(
|
|
|
lambda *tensors: sum(x * weight for x, weight in zip(tensors, survived_expert_probas)),
|