|
@@ -20,7 +20,7 @@ def test_remote_module_call():
|
|
|
[(None,), {}], xx)
|
|
|
|
|
|
grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
|
|
|
- grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
|
|
|
+ grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)
|
|
|
|
|
|
# reference outputs: call all experts manually and average their outputs with softmax probabilities
|
|
|
probs = torch.softmax(logits, 0)
|