소스 검색

make run_server executable via python -m

justheuristic 5 년 전
부모
커밋
3696bb065c
2개의 변경된 파일, 40개의 추가 그리고 4개의 삭제
  1. 2 1
      tests/test_moe.py
  2. 38 3
      tests/test_utils/run_server.py

+ 2 - 1
tests/test_moe.py

@@ -18,7 +18,8 @@ def test_remote_module_call():
     logits = torch.randn(3, requires_grad=True)
     random_proj = torch.randn_like(xx)
 
-    with background_server(num_experts=num_experts, no_optimizer=True, no_network=True) as (localhost, server_port):
+    with background_server(num_experts=num_experts,  device='cpu',
+                           no_optimizer=True, no_network=True) as (localhost, server_port):
         experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
         moe_output, = tesseract.client.moe._RemoteMoECall.apply(
             logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,

+ 38 - 3
tests/test_utils/run_server.py

@@ -1,5 +1,7 @@
+import resource
 from contextlib import contextmanager
 import multiprocessing as mp
+import argparse
 
 import torch
 import tesseract
@@ -7,7 +9,7 @@ from .layers import name_to_block
 
 
 def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
-                      expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device='cpu', no_optimizer=False,
+                      expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
                       no_network=False, initial_peers=(), network_port=None, verbose=True, start=True, **kwargs
                       ) -> tesseract.TesseractServer:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
@@ -80,5 +82,38 @@ def background_server(*args, verbose=True, **kwargs):
 
 
 if __name__ == '__main__':
-    with background_server() as (host, port):
-        mp.Event().wait()  # aka fall asleep forever
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', type=str, default='0.0.0.0', required=False)
+    parser.add_argument('--port', type=int, default=None, required=False)
+    parser.add_argument('--num_experts', type=int, default=1, required=False)
+    parser.add_argument('--expert_cls', type=str, default='ffn', required=False)
+    parser.add_argument('--hidden_dim', type=int, default=1024, required=False)
+    parser.add_argument('--num_handlers', type=int, default=None, required=False)
+    parser.add_argument('--expert_prefix', type=str, default='expert.', required=False)
+    parser.add_argument('--expert_offset', type=int, default=0, required=False)
+    parser.add_argument('--max_batch_size', type=int, default=16384, required=False)
+    parser.add_argument('--device', type=str, default=None, required=False)
+    parser.add_argument('--no_optimizer', action='store_true')
+    parser.add_argument('--no_network', action='store_true')
+    parser.add_argument('--initial_peers', type=str, default="[]", required=False)
+    parser.add_argument('--network_port', type=int, default=None, required=False)
+
+    parser.add_argument('--increase_file_limit', action='store_true')
+
+    args = vars(parser.parse_args())
+
+    if args.pop('increase_file_limit'):
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        try:
+            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
+            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
+        except:
+            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))
+
+    args['initial_peers'] = eval(args['initial_peers'])
+
+    try:
+        server = make_dummy_server(**args, start=False, verbose=True)
+        server.join()
+    finally:
+        server.shutdown()