Browse Source

wip: implement grad wrt logits

justheuristic 5 years ago
parent
commit
785b029d48
1 changed file with 1 addition and 2 deletions
  1. 1 2
      tesseract/client/moe.py

+ 1 - 2
tesseract/client/moe.py

@@ -201,7 +201,6 @@ class _RemoteMoECall(torch.autograd.Function):
         stacked_alive_outputs = tuple(map(torch.stack, alive_outputs))
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
-        print('>>>>>>>>', flat_average_outputs)
 
 
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
@@ -253,4 +252,4 @@ class _RemoteMoECall(torch.autograd.Function):
 
 
 
 
 def dot_along_first_axis(x, y):
 def dot_along_first_axis(x, y):
-    return (x.view(-1, *[1] * (y.ndim - 1))).sum(0)
+    return (x.view(-1, *[1] * (y.ndim - 1)) * y).sum(0)