Explorar el Código

wip: implement grad wrt logits

justheuristic hace 5 años
padre
commit
676066baed
Se han modificado 1 ficheros con 1 adiciones y 0 borrados
  1. 1 0
      tesseract/client/moe.py

+ 1 - 0
tesseract/client/moe.py

@@ -201,6 +201,7 @@ class _RemoteMoECall(torch.autograd.Function):
         stacked_alive_outputs = tuple(map(torch.stack, alive_outputs))
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
+        print(flat_average_outputs)
 
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)