Przeglądaj źródła

wip: implement grad wrt logits

justheuristic 5 lat temu
rodzic
commit
c005da2089
1 zmienionych plików z 1 dodań i 1 usunięć
  1. 1 1
      tesseract/client/moe.py

+ 1 - 1
tesseract/client/moe.py

@@ -201,7 +201,7 @@ class _RemoteMoECall(torch.autograd.Function):
         stacked_alive_outputs = tuple(map(torch.stack, alive_outputs))
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
-        print(flat_average_outputs)
+        print('!!!!', flat_average_outputs, flush=True)
 
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)