فهرست منبع

grad logits wrt actual logits

justheuristic 5 سال پیش
والد
کامیت
b20f3ee985
1فایلهای تغییر یافته به همراه1 افزوده شده و 1 حذف شده
  1. 1 1
      tests/test_moe.py

+ 1 - 1
tests/test_moe.py

@@ -20,7 +20,7 @@ def test_remote_module_call():
         [(None,), {}], xx)
 
     grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
-    grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
+    grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)
 
     # reference outputs: call all experts manually and average their outputs with softmax probabilities
     probs = torch.softmax(logits, 0)