Bläddra i källkod

stack outputs in one tensor

justheuristic 5 år sedan
förälder
incheckning
94322034f3
1 ändrade filer med 2 tillägg och 2 borttagningar
  1. 2 2
      tesseract/client/moe.py

+ 2 - 2
tesseract/client/moe.py

@@ -79,12 +79,12 @@ class RemoteMixtureOfExperts(nn.Module):
         flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))
 
         batch_jobs_args = tuple(
-            (expert_logits[i, :len(chosen_experts)], chosen_experts[i], self.k_min, self.timeout_after_k_min,
+            (expert_logits[i, :len(chosen_experts[i])], chosen_experts[i], self.k_min, self.timeout_after_k_min,
              self.backward_k_min, self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
             for i in range(len(input))
         )
 
-        averaged_outputs_flat = map(torch.cat, map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)
+        averaged_outputs_flat = map(torch.cat, zip(*map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)))
         return nested_pack(averaged_outputs_flat, self.outputs_schema)
 
     def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]: