Browse Source

background_server is now a contextmanager

justheuristic 5 years ago
parent
commit
7bdb5d919b
1 changed files with 9 additions and 9 deletions
  1. 9 9
      tests/test_moe.py

+ 9 - 9
tests/test_moe.py

@@ -5,19 +5,19 @@ from test_utils.run_server import background_server
 
 def test_remote_module_call():
     """ Check that remote_module_call returns correct outputs and gradients if called directly """
+    num_experts = 8
+    k_min = 1
+    timeout_after_k_min = None
+    backward_k_min = 1
+    timeout_total = None
+    backward_timeout = None
+
     xx = torch.randn(32, 1024, requires_grad=True)
     logits = torch.randn(3, requires_grad=True)
     random_proj = torch.randn_like(xx)
-    num_experts = 8
-
-    with background_server(num_experts=num_experts, no_optimizer=True, no_network=True, verbose=True) as server:
-        experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server.port) for i in range(num_experts)]
 
-        k_min = 1
-        timeout_after_k_min = None
-        backward_k_min = 1
-        timeout_total = None
-        backward_timeout = None
+    with background_server(num_experts=num_experts, no_optimizer=True, no_network=True) as localhost, server_port:
+        experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
         moe_output, = tesseract.client.moe._RemoteMoECall.apply(
             logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
             [(None,), {}], xx)