Bläddra i källkod

wip: implement grad wrt logits

justheuristic 5 år sedan
förälder
incheckning
676066baed
1 ändrade filer med 1 tillägg och 0 borttagningar
  1. 1 0
      tesseract/client/moe.py

+ 1 - 0
tesseract/client/moe.py

@@ -201,6 +201,7 @@ class _RemoteMoECall(torch.autograd.Function):
         stacked_alive_outputs = tuple(map(torch.stack, alive_outputs))
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
+        print(flat_average_outputs)
 
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)