|
@@ -79,12 +79,12 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))
|
|
|
|
|
|
batch_jobs_args = tuple(
|
|
|
- (expert_logits[i, :len(chosen_experts)], chosen_experts[i], self.k_min, self.timeout_after_k_min,
|
|
|
+ (expert_logits[i, :len(chosen_experts[i])], chosen_experts[i], self.k_min, self.timeout_after_k_min,
|
|
|
self.backward_k_min, self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
|
|
|
for i in range(len(input))
|
|
|
)
|
|
|
|
|
|
- averaged_outputs_flat = map(torch.cat, map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)
|
|
|
+ averaged_outputs_flat = map(torch.cat, zip(*map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)))
|
|
|
return nested_pack(averaged_outputs_flat, self.outputs_schema)
|
|
|
|
|
|
def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
|