
debugprint

justheuristic 5 years ago
parent
commit
0875f73522
2 changed files with 4 additions and 2 deletions
  1. tesseract/client/moe.py (+3 −1)
  2. tests/test_moe.py (+1 −1)

+ 3 - 1
tesseract/client/moe.py

@@ -199,10 +199,12 @@ class _RemoteMoECall(torch.autograd.Function):
         alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
 
         stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
-        print(f'>> {[outs[0].norm() for outs in alive_outputs]}')
+
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
 
+        print(f'ours {[flat_average_outputs[0].min(), flat_average_outputs[0].max(), flat_average_outputs[0].norm()]}')
+
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
         ctx._alive_contexts = alive_contexts
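Note: `dot_along_first_axis` is not shown in this diff. A minimal sketch of what it presumably computes, given how it is called here with `alive_expert_probs` over stacked per-expert outputs, assuming it contracts the leading (expert) axis with a 1-d weight vector:

import torch

def dot_along_first_axis(weights: torch.Tensor, stacked: torch.Tensor) -> torch.Tensor:
    # Hypothetical reimplementation: weighted sum over the leading (expert) axis.
    # weights: [num_experts], stacked: [num_experts, ...];
    # broadcast weights over the trailing dimensions, then reduce dim 0.
    return (weights.view(-1, *([1] * (stacked.dim() - 1))) * stacked).sum(dim=0)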

+ 1 - 1
tests/test_moe.py

@@ -28,8 +28,8 @@ def test_remote_module_call():
         # reference outputs: call all experts manually and average their outputs with softmax probabilities
         probs = torch.softmax(logits, 0)
         outs = [expert(xx) for expert in experts[:3]]
-        print(f'ref {[out.norm() for out in outs]}')
         manual_output = sum(p * x for p, x in zip(probs, outs))
+        print(f'ref {[manual_output.min(), manual_output.max(), manual_output.norm()]}')
         grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
         grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
         grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
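The two debug prints added in this commit compare the averaged output computed in `moe.py` ('ours') against the manually averaged reference in the test ('ref') by min, max, and norm. A standalone sketch of the equivalence being checked, assuming the weighted-average semantics above (tensor shapes here are arbitrary stand-ins, not taken from the commit):

import torch

# Hypothetical sanity check mirroring the 'ref' vs 'ours' debug prints.
probs = torch.softmax(torch.randn(3), dim=0)        # expert gating weights
outs = [torch.randn(4, 8) for _ in range(3)]        # per-expert outputs

manual = sum(p * x for p, x in zip(probs, outs))    # test's reference path
stacked = (probs.view(-1, 1, 1) * torch.stack(outs)).sum(dim=0)  # moe.py path

assert torch.allclose(manual, stacked, atol=1e-6)
print(f'ref  {[manual.min(), manual.max(), manual.norm()]}')
print(f'ours {[stacked.min(), stacked.max(), stacked.norm()]}')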