
debugprint

justheuristic committed 5 years ago
commit 1e06c47121
2 changed files with 2 additions and 2 deletions
  1. tesseract/client/moe.py (+1 −1)
  2. tests/test_moe.py (+1 −1)

tesseract/client/moe.py (+1 −1)

@@ -197,9 +197,9 @@ class _RemoteMoECall(torch.autograd.Function):
         # 2. compute softmax weights for alive experts and average outputs
         alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
         alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
-        print(f'{alive_expert_probs=}')
 
         stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
+        print(f'>> {[outs[0].norm() for outs in alive_outputs]}')
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
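Both prints now log per-expert output norms, so the client-side average can be compared against the reference computation in the test below. For context, here is a minimal sketch of what dot_along_first_axis presumably computes, a softmax-weighted sum over the expert dimension; the real helper is defined elsewhere in the repo, so this is an assumption, not its actual source:

# minimal sketch, assuming dot_along_first_axis is a weighted sum over dim 0
import torch

def dot_along_first_axis(weights, stacked):
    # weights: [num_alive_experts]; stacked: [num_alive_experts, *output_shape]
    # broadcast weights over the trailing dims, then reduce over experts (dim 0)
    return (weights.view(-1, *([1] * (stacked.dim() - 1))) * stacked).sum(dim=0)

# usage: average 3 expert outputs of shape [batch=4, hidden=8]
probs = torch.softmax(torch.randn(3), dim=0)
stacked_out = torch.randn(3, 4, 8)
averaged = dot_along_first_axis(probs, stacked_out)  # shape [4, 8]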
 

tests/test_moe.py (+1 −1)

@@ -27,7 +27,7 @@ def test_remote_module_call():
 
         # reference outputs: call all experts manually and average their outputs with softmax probabilities
         probs = torch.softmax(logits, 0)
-        print(f'ref {probs=}')
         outs = [expert(xx) for expert in experts[:3]]
+        print(f'ref {[out.norm() for out in outs]}')
         manual_output = sum(p * x for p, x in zip(probs, outs))
         grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
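The added print is moved below the line that defines outs, since the diff as committed would raise a NameError by referencing outs before assignment. The test's gradient check relies on a random projection: the gradient of sum(random_proj * output) with respect to the input equals the Jacobian contracted with random_proj, so two implementations can be compared along a random direction without materializing the full Jacobian. A self-contained sketch of the same pattern; variable names here are illustrative, not from the repo:

import torch

xx = torch.randn(4, 8, requires_grad=True)
linear = torch.nn.Linear(8, 8)

out_a = linear(xx)                                       # implementation A
out_b = torch.addmm(linear.bias, xx, linear.weight.t())  # mathematically equivalent B

# grad of sum(r * out) w.r.t. xx is r contracted with the Jacobian,
# i.e. a random slice of it: cheap to compute and sensitive to mismatches
random_proj = torch.randn_like(out_a)
grad_a, = torch.autograd.grad(torch.sum(random_proj * out_a), xx)
grad_b, = torch.autograd.grad(torch.sum(random_proj * out_b), xx)
assert torch.allclose(grad_a, grad_b, atol=1e-6)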