
background_server is now a contextmanager

justheuristic 5 years ago
parent
commit
7bdb5d919b
1 changed file with 9 additions and 9 deletions

+ 9 - 9
tests/test_moe.py

@@ -5,19 +5,19 @@ from test_utils.run_server import background_server
 
 def test_remote_module_call():
     """ Check that remote_module_call returns correct outputs and gradients if called directly """
+    num_experts = 8
+    k_min = 1
+    timeout_after_k_min = None
+    backward_k_min = 1
+    timeout_total = None
+    backward_timeout = None
+
     xx = torch.randn(32, 1024, requires_grad=True)
     logits = torch.randn(3, requires_grad=True)
     random_proj = torch.randn_like(xx)
-    num_experts = 8
-
-    with background_server(num_experts=num_experts, no_optimizer=True, no_network=True, verbose=True) as server:
-        experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server.port) for i in range(num_experts)]
 
-        k_min = 1
-        timeout_after_k_min = None
-        backward_k_min = 1
-        timeout_total = None
-        backward_timeout = None
+    with background_server(num_experts=num_experts, no_optimizer=True, no_network=True) as (localhost, server_port):
+        experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
         moe_output, = tesseract.client.moe._RemoteMoECall.apply(
             logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
             [(None,), {}], xx)
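
The new `as (localhost, server_port)` unpacking implies that background_server now yields a (hostname, port) pair instead of a server object. Below is a minimal sketch of how such a context manager could look, assuming contextlib.contextmanager and a background daemon process; the find_open_port and run_server helpers are illustrative stand-ins, not the repository's actual test_utils.run_server implementation.

import contextlib
import multiprocessing
import socket
import time


def find_open_port() -> int:
    # Ask the OS for a free TCP port (hypothetical helper, not from the repo).
    with socket.socket() as sock:
        sock.bind(('', 0))
        return sock.getsockname()[1]


def run_server(num_experts: int, port: int, **kwargs) -> None:
    # Stand-in for the real server loop: the actual project would start its
    # expert server here; this placeholder just keeps the process alive.
    time.sleep(3600)


@contextlib.contextmanager
def background_server(num_experts: int, no_optimizer: bool = False, no_network: bool = False):
    # Launch the server in a daemon process, yield (hostname, port) to the caller,
    # and tear the process down once the with-block exits.
    port = find_open_port()
    runner = multiprocessing.Process(
        target=run_server, daemon=True,
        kwargs=dict(num_experts=num_experts, port=port,
                    no_optimizer=no_optimizer, no_network=no_network))
    runner.start()
    try:
        yield 'localhost', port
    finally:
        runner.terminate()
        runner.join()

With this shape, the test can connect each tesseract.RemoteExpert to the yielded server_port without touching the server object directly, which is exactly what the updated with-statement in tests/test_moe.py does.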