Browse Source

debugprint

justheuristic 5 years ago
parent
commit
b6ad7428a5
1 changed files with 1 additions and 1 deletions
  1. 1 1
      tests/test_moe.py

+ 1 - 1
tests/test_moe.py

@@ -27,8 +27,8 @@ def test_remote_module_call():
 
         # reference outputs: call all experts manually and average their outputs with softmax probabilities
         probs = torch.softmax(logits, 0)
-        print(f'ref {[out.norm() for out in outs]}')
         outs = [expert(xx) for expert in experts[:3]]
+        print(f'ref {[out.norm() for out in outs]}')
         manual_output = sum(p * x for p, x in zip(probs, outs))
         grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
         grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)