Ver código fonte

wip: implement grad wrt logits

justheuristic 5 anos atrás
pai
commit
785e115d89
1 arquivos alterados com 2 adições e 2 exclusões
  1. 2 2
      tesseract/client/moe.py

+ 2 - 2
tesseract/client/moe.py

@@ -199,8 +199,8 @@ class _RemoteMoECall(torch.autograd.Function):
         alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
 
         stacked_alive_outputs = tuple(map(torch.stack, alive_outputs))
-        flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
-                                     for stacked_out in stacked_alive_outputs)
+        flat_average_outputs = sum(dot_along_first_axis(alive_expert_probs, stacked_out)
+                                   for stacked_out in stacked_alive_outputs)
 
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)