@@ -202,7 +202,7 @@ class _RemoteMoECall(torch.autograd.Function):
lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))

# 3. save individual outputs for backward pass
- ctx.save_for_backward(flat_inputs, expert_logits, alive_ix, alive_expert_probs)
+ ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs)
ctx._alive_contexts = alive_contexts
ctx._backward_k_min = backward_k_min
ctx._backward_timeout = backward_timeout
@@ -212,7 +212,8 @@ class _RemoteMoECall(torch.autograd.Function):
@once_differentiable
def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
""" Like normal backward, but we ignore any experts that failed during backward pass """
- flat_inputs, expert_logits, alive_ix, alive_expert_probas = ctx.saved_tensors
+ # TODO: add a dummy tensor or something else that ensures the backward pass is not skipped even if the inputs do not require grad
+ expert_logits, alive_ix, alive_expert_probas = ctx.saved_tensors
alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout

jobs = [partial(cls._run_expert_backward, ctx, prob, grad_outputs_flat)