justheuristic пре 5 година
родитељ
комит
60af3952c9
1 измењен фајл са 2 додата и 2 уклоњена реда
  1. 2 2
      tesseract/client/moe.py

+ 2 - 2
tesseract/client/moe.py

@@ -179,8 +179,8 @@ class _RemoteMoECall(torch.autograd.Function):
     """
     @classmethod
     def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
-                *flat_inputs: torch.Tensor, input_schema, k_min: int, timeout_after_k_min: float, backward_k_min: int,
-                timeout_total: Optional[float], backward_timeout: Optional[float]) -> Tuple[torch.Tensor]:
+                k_min: int, timeout_after_k_min: float, backward_k_min: int, timeout_total: Optional[float],
+                backward_timeout: Optional[float], input_schema, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         expert_args, expert_kwargs = nested_pack(flat_inputs, structure=input_schema)
         assert expert_logits.ndim == 1 and len(expert_logits) == len(experts)