@@ -212,7 +212,6 @@ class _RemoteMoECall(torch.autograd.Function):
     @once_differentiable
     def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         """ Like normal backward, but we ignore any experts that failed during backward pass """
-        #TODO add dummy tensor or something else that ensures that backward pass is not omitted even if inputs do not require grad
         expert_logits, alive_ix, alive_expert_probas = ctx.saved_tensors
         alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
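
The removed TODO refers to a PyTorch quirk: if none of the inputs to a `torch.autograd.Function` require grad, autograd skips the node entirely and `backward` is never called, so any per-expert failure handling there would silently not run. Below is a minimal sketch of the dummy-tensor workaround the comment mentions; the class name `_CallWithDummy` and the toy computation are illustrative assumptions, not code from this repository.

```python
import torch
from typing import Optional, Tuple


class _CallWithDummy(torch.autograd.Function):
    """Illustrative sketch: an extra requires_grad=True 'dummy' tensor forces autograd
    to build a graph node, so backward runs even when the real inputs do not require grad."""

    @staticmethod
    def forward(ctx, dummy: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        return x * 2  # placeholder for the real (e.g. remote-expert) computation

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        (x,) = ctx.saved_tensors
        # Failure handling / bookkeeping would go here; it now runs unconditionally
        # because the dummy keeps this node in the autograd graph.
        grad_x = grad_output * 2 if ctx.needs_input_grad[1] else None
        return None, grad_x  # no gradient is produced for the dummy tensor


# Usage: the dummy is the only tensor that requires grad.
dummy = torch.empty(0, requires_grad=True)
x = torch.randn(3)                      # requires_grad=False
out = _CallWithDummy.apply(dummy, x)
out.sum().backward()                    # backward above is still invoked
```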