|
@@ -28,8 +28,8 @@ def test_remote_module_call():
|
|
|
# reference outputs: call the first three experts manually and take the softmax-weighted sum of their outputs
|
|
|
probs = torch.softmax(logits, 0)
|
|
|
outs = [expert(xx) for expert in experts[:3]]
|
|
|
- print(f'ref {[out.norm() for out in outs]}')
|
|
|
manual_output = sum(p * x for p, x in zip(probs, outs))
|
|
|
+ print(f'ref {[manual_output.min(), manual_output.max(), manual_output.norm()]}')
|
|
|
grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
|
|
|
grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
|
|
|
grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
|