@@ -45,24 +45,24 @@ def test_remote_module_call():
def test_compute_expert_scores():
    with background_server(device='cpu') as (server_addr, server_port, network_port):
-        dht = tesseract.TesseractNetwork(('localhost', network_port), port=tesseract.find_open_port(), start=True)
-        moe = tesseract.client.moe.RemoteMixtureOfExperts(
-            network=dht, in_features=1024, grid_size=[40], k_best=4, k_min=1, timeout_after_k_min=1,
-            uid_prefix='expert')
-        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
-        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
-        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
-        grid_scores = [gx, gy]
-        batch_experts = [
-            [tesseract.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}') for e in range(len(ii[b]))]
-            for b in range(len(ii))
-        ]  # note: these experts do not exist on the server; we use them only to test moe.compute_expert_scores
-        logits = moe.compute_expert_scores([gx, gy], batch_experts)
-        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
-        assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, "compute_expert_scores needs to backprop wrt grid scores"
+        try:
+            dht = tesseract.TesseractNetwork(('localhost', network_port), port=tesseract.find_open_port(), start=True)
+            moe = tesseract.client.moe.RemoteMixtureOfExperts(
+                network=dht, in_features=1024, grid_size=[40], k_best=4, k_min=1, timeout_after_k_min=1,
+                uid_prefix='expert')
+            gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
+            ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
+            jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
+            batch_experts = [
+                [tesseract.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}') for e in range(len(ii[b]))]
+                for b in range(len(ii))
+            ]  # note: these experts do not exist on the server; we use them only to test moe.compute_expert_scores
+            logits = moe.compute_expert_scores([gx, gy], batch_experts)
+            torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
+            assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, "compute_expert_scores didn't backprop"
 
-        for b in range(len(ii)):
-            for e in range(len(ii[b])):
-                assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
-
-        dht.shutdown()
+            for b in range(len(ii)):
+                for e in range(len(ii[b])):
+                    assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
+        finally:
+            dht.shutdown()
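
For readers skimming the test: the inner assertions encode the invariant that an expert whose uid ends in grid coordinates (i, j) should receive the score gx[b, i] + gy[b, j], i.e. the sum of the per-dimension grid scores at its coordinates. Below is a minimal standalone sketch of that reference computation in plain PyTorch; the helper name naive_expert_scores is made up here for illustration and is not part of tesseract.

    import torch

    def naive_expert_scores(grid_scores, coords):
        # Reference for the assertion above: the score of expert e in sample b is
        # the sum of grid_scores[d][b, coords[b][e][d]] over grid dimensions d.
        # Positions with no expert are padded with -inf (a choice made for this sketch).
        batch_size = grid_scores[0].shape[0]
        max_experts = max(len(sample) for sample in coords)
        scores = torch.full((batch_size, max_experts), float('-inf'))
        for b, sample in enumerate(coords):
            for e, expert_coords in enumerate(sample):
                scores[b, e] = sum(grid_scores[d][b, idx] for d, idx in enumerate(expert_coords))
        return scores

Under that assumption, naive_expert_scores([gx, gy], [list(zip(ii[b], jj[b])) for b in range(len(ii))]) should agree entry-wise with moe.compute_expert_scores([gx, gy], batch_experts) at the positions that correspond to real experts, which is exactly what the torch.allclose loop checks.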