|
@@ -179,8 +179,8 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
"""
|
|
|
@classmethod
|
|
|
def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
|
|
|
- *flat_inputs: torch.Tensor, input_schema, k_min: int, timeout_after_k_min: float, backward_k_min: int,
|
|
|
- timeout_total: Optional[float], backward_timeout: Optional[float]) -> Tuple[torch.Tensor]:
|
|
|
+ k_min: int, timeout_after_k_min: float, backward_k_min: int, timeout_total: Optional[float],
|
|
|
+ backward_timeout: Optional[float], input_schema, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
|
|
|
expert_args, expert_kwargs = nested_pack(flat_inputs, structure=input_schema)
|
|
|
assert expert_logits.ndim == 1 and len(expert_logits) == len(experts)
|
|
|
|