justheuristic 5 years ago
parent
commit
fe728568b1
1 changed files with 2 additions and 2 deletions
  1. 2 2
      tesseract/client/moe.py

+ 2 - 2
tesseract/client/moe.py

@@ -79,8 +79,8 @@ class RemoteMixtureOfExperts(nn.Module):
         flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))
 
         batch_jobs_args = tuple(
-            (expert_logits[i], chosen_experts[i], self.k_min, self.timeout_after_k_min, self.backward_k_min,
-             self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
+            (expert_logits[i, :len(chosen_experts)], chosen_experts[i], self.k_min, self.timeout_after_k_min,
+             self.backward_k_min, self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
             for i in range(len(input))
         )