ソースを参照

use lists for gatehr

justheuristic 5 年 前
コミット
8030c075c9
1 ファイル変更3 行追加3 行削除
  1. 3 3
      tesseract/client/moe.py

+ 3 - 3
tesseract/client/moe.py

@@ -195,7 +195,7 @@ class _RemoteMoECall(torch.autograd.Function):
         #       \-- a list of autograd contexts, used for parallel backward
 
         # 2. compute softmax weights for alive experts and average outputs
-        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
+        alive_expert_probs = torch.softmax(expert_logits[list(alive_ix)], dim=0)
 
         flat_average_outputs = tuple(map(
             lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
@@ -219,8 +219,8 @@ class _RemoteMoECall(torch.autograd.Function):
         results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
         survived_backward, survived_grad_inputs = zip(*((alive_ix[i], grads) for i, grads in enumerate(results)))
 
-        survived_ix = alive_ix[survived_backward]
-        survived_expert_probas = torch.softmax(expert_logits[survived_ix], dim=0)
+        survived_ix = alive_ix[list(survived_backward)]
+        survived_expert_probas = torch.softmax(expert_logits[list(survived_ix)], dim=0)
 
         flat_grad_inputs = tuple(map(
             lambda *tensors: sum(x * weight for x, weight in zip(tensors, survived_expert_probas)),