|
@@ -177,8 +177,6 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
This function can recover from individual failures during forward and/or backward passes.
|
|
|
For a user-friendly version of this function, use the RemoteMixtureOfExperts module.
|
|
|
"""
|
|
|
- MIN_TOTAL_WEIGHT = 1e-3
|
|
|
-
|
|
|
@classmethod
|
|
|
def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
|
|
|
*flat_inputs: torch.Tensor, input_schema, k_min: int, timeout_after_k_min: float, backward_k_min: int,
|
|
@@ -229,7 +227,7 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
*survived_grad_inputs))
|
|
|
|
|
|
grad_logits = None # TODO
|
|
|
- return (grad_logits, None, *flat_grad_inputs, None, None, None, None, None, None)
|
|
|
+ return grad_logits, None, *flat_grad_inputs, None, None, None, None, None, None
|
|
|
|
|
|
@staticmethod
|
|
|
def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
|