Browse Source

list -> tensor

justheuristic 5 years ago
parent
commit
c8889bde96
1 changed file with 5 additions and 4 deletions
  1. tesseract/client/moe.py: +5 −4

+ 5 - 4
tesseract/client/moe.py

@@ -195,7 +195,8 @@ class _RemoteMoECall(torch.autograd.Function):
         #       \-- a list of autograd contexts, used for parallel backward
 
         # 2. compute softmax weights for alive experts and average outputs
-        alive_expert_probs = torch.softmax(expert_logits[list(alive_ix)], dim=0)
+        alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
+        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
 
         flat_average_outputs = tuple(map(
             lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
@@ -218,9 +219,9 @@ class _RemoteMoECall(torch.autograd.Function):
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
         results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
         survived_backward, survived_grad_inputs = zip(*((alive_ix[i], grads) for i, grads in enumerate(results)))
-
-        survived_ix = alive_ix[list(survived_backward)]
-        survived_expert_probas = torch.softmax(expert_logits[list(survived_ix)], dim=0)
+        survived_backward = torch.as_tensor(survived_backward, device=expert_logits.device)
+        survived_ix = alive_ix[survived_backward]
+        survived_expert_probas = torch.softmax(expert_logits[survived_ix], dim=0)
 
         flat_grad_inputs = tuple(map(
             lambda *tensors: sum(x * weight for x, weight in zip(tensors, survived_expert_probas)),