justheuristic 5 vuotta sitten
vanhempi
commit
b6d05c6190
2 muutettua tiedostoa jossa 3 lisäystä ja 0 poistoa
  1. 2 0
      tesseract/client/moe.py
  2. 1 0
      tests/test_moe.py

+ 2 - 0
tesseract/client/moe.py

@@ -197,10 +197,12 @@ class _RemoteMoECall(torch.autograd.Function):
         # 2. compute softmax weights for alive experts and average outputs
         alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
         alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
+        print(f'{alive_expert_probs=}')
 
         stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
+
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
         ctx._alive_contexts = alive_contexts

+ 1 - 0
tests/test_moe.py

@@ -27,6 +27,7 @@ def test_remote_module_call():
 
         # reference outputs: call all experts manually and average their outputs with softmax probabilities
         probs = torch.softmax(logits, 0)
+        print(f'ref {probs=}')
         outs = [expert(xx) for expert in experts[:3]]
         manual_output = sum(p * x for p, x in zip(probs, outs))
         grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)