浏览代码

safer shutdown order

justheuristic 5 年之前
父节点
当前提交
6605b00d05
共有 2 个文件被更改,包括 25 次插入、8 次删除
  1. 3 1
      tesseract/server/__init__.py
  2. 22 7
      tests/test_utils/run_server.py

+ 3 - 1
tesseract/server/__init__.py

@@ -111,10 +111,12 @@ class TesseractServer(threading.Thread):
         self.ready.clear()
         for process in self.conn_handlers:
             process.terminate()
-        self.runtime.shutdown()
+
         if self.network is not None:
             self.network.shutdown()
 
+        self.runtime.shutdown()
+
 
 def socket_loop(sock, experts):
     """ catch connections, send tasks to processing, respond with results """

+ 22 - 7
tests/test_utils/run_server.py

@@ -3,17 +3,17 @@ import torch
 import tesseract
 from .layers import name_to_block
 from contextlib import contextmanager
+import multiprocessing as mp
 
 
-@contextmanager
-def background_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
+def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
                       expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
-                      no_network=False, initial_peers=(), network_port=None, verbose=False, **kwargs
+                      no_network=False, initial_peers=(), network_port=None, verbose=False, start=True, **kwargs
                       ) -> tesseract.TesseractServer:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     if verbose and len(kwargs) == 0:
         print("Ignored kwargs:", kwargs)
-    expert_cls in name_to_block
+    assert expert_cls in name_to_block
     num_handlers = num_handlers if num_handlers is not None else num_experts * 8
     device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -21,7 +21,8 @@ def background_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
     network = None
     if not no_network:
         initial_peers = eval(initial_peers)
-        network = tesseract.TesseractNetwork(*initial_peers, port=network_port or tesseract.find_open_port(), start=True)
+        network = tesseract.TesseractNetwork(*initial_peers, port=network_port or tesseract.find_open_port(),
+                                             start=True)
         if verbose:
             print("Parsed initial peers:", initial_peers)
             print(f"Running network node on port {network.port}")
@@ -37,19 +38,33 @@ def background_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
                                                       outputs_schema=tesseract.BatchTensorProto(hidden_dim),
                                                       max_batch_size=max_batch_size,
                                                       )
-    # start server
+    # actually start server
     server = tesseract.TesseractServer(
         network, experts, addr=host, port=port or tesseract.find_open_port(),
         conn_handler_processes=num_handlers, device=device)
-    try:
+
+    if start:
         server.run_in_background(await_ready=True)
         if verbose:
             print(f"Running server at {server.addr}:{server.port}")
             print(f"Active experts of type {expert_cls}: {list(experts.keys())}")
+    return server
+
+
+def background_server(*args, verbose=True, **kwargs):
+    """ Runs server in a background process and returns a reference to it. """
+    try:
+        server = make_dummy_server(*args, verbose=verbose, start=False, **kwargs)
+
+        runner = mp.Process(target=lambda: (server.start(), server.join()))
+        runner.start()
+        server.ready.wait()
         yield server
+        runner.join()
     finally:
         if verbose:
             print("Shutting down server...")
         server.shutdown()
+        runner.terminate()
         if verbose:
             print("Server shut down successfully.")