@@ -66,3 +66,8 @@ def test_compute_expert_scores():
assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
finally:
dht.shutdown()
+
+if __name__ == '__main__':
+ test_remote_module_call()
+ test_compute_expert_scores()