
list -> tensor

justheuristic committed 5 years ago · commit 5cbcf79b00
1 changed file with 3 additions and 2 deletions

+ 3 - 2
tesseract/client/moe.py

@@ -202,7 +202,7 @@ class _RemoteMoECall(torch.autograd.Function):
             lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
 
         # 3. save individual outputs for backward pass
-        ctx.save_for_backward(flat_inputs, expert_logits, alive_ix, alive_expert_probs)
+        ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs)
         ctx._alive_contexts = alive_contexts
         ctx._backward_k_min = backward_k_min
         ctx._backward_timeout = backward_timeout
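For context, here is a minimal sketch (hypothetical names, not from moe.py) of the pattern this hunk relies on, assuming PyTorch's custom autograd.Function API: tensors that backward actually reads go through ctx.save_for_backward, while non-tensor state such as contexts and timeouts is attached to ctx as plain attributes.

import torch
from typing import Optional, Tuple

class _ScaleByProb(torch.autograd.Function):
    """Toy example: scale the input by a fixed probability tensor."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, prob: torch.Tensor, timeout: float) -> torch.Tensor:
        # save only the tensors that backward will actually read
        ctx.save_for_backward(prob)
        # non-tensor metadata goes directly onto ctx, as in the diff above
        ctx._timeout = timeout
        return x * prob

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        (prob,) = ctx.saved_tensors
        # one gradient slot per forward input: x, prob, timeout
        return grad_output * prob, None, None

Dropping flat_inputs from save_for_backward follows this rule: the backward pass below never reads it, so keeping it alive only wastes memory.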
@@ -212,7 +212,8 @@ class _RemoteMoECall(torch.autograd.Function):
     @once_differentiable
     def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         """ Like normal backward, but we ignore any experts that failed during backward pass """
-        flat_inputs, expert_logits, alive_ix, alive_expert_probas  = ctx.saved_tensors
+        # TODO: add a dummy tensor (or similar) to ensure the backward pass is not skipped even when inputs do not require grad
+        expert_logits, alive_ix, alive_expert_probas = ctx.saved_tensors
         alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
 
         jobs = [partial(cls._run_expert_backward, ctx, prob, grad_outputs_flat)
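The TODO above points at a PyTorch subtlety: if none of the inputs to an autograd.Function require grad, autograd never records the node and backward is silently skipped. A hedged sketch of the dummy-tensor workaround the comment hints at (names are illustrative, not from this repository):

import torch

class _AlwaysBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, dummy: torch.Tensor) -> torch.Tensor:
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # reached even when x.requires_grad is False, because dummy requires grad;
        # return a gradient for x only if autograd asked for one
        grad_x = grad_output if ctx.needs_input_grad[0] else None
        return grad_x, None

x = torch.randn(3)                          # plain tensor, no grad required
dummy = torch.empty(0, requires_grad=True)  # keeps the node in the graph
out = _AlwaysBackward.apply(x, dummy)
assert out.requires_grad                    # true because of dummy
out.sum().backward()                        # backward runs thanks to dummy

Because the output requires grad whenever any input does, threading a zero-element dummy tensor through apply() guarantees the backward hook fires even when the real inputs arrive detached.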