
run_server update (#42)

* pool is now daemon

* specify forking

* use spawn instead of fork, add shutdown timeout, fix docstring

* use single pipe for background_server control

* update documentation on run_server

* daemonize connection handlers
justheuristic committed 5 years ago
commit bb7f867ca7

+ 5 - 6
hivemind/runtime/task_pool.py

@@ -19,11 +19,11 @@ from ..utils import SharedFuture
 Task = namedtuple("Task", ("future", "args"))
 
 
-class TaskPoolBase(mp.Process):
+class TaskPoolBase(mp.context.ForkProcess):
     """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """
 
-    def __init__(self, process_func: callable):
-        super().__init__()
+    def __init__(self, process_func: callable, daemon=True):
+        super().__init__(daemon=daemon)
         self.process_func = process_func
         self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
 
@@ -66,9 +66,8 @@ class TaskPool(TaskPoolBase):
     """
 
     def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
-                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, start=False):
-
-        super().__init__(process_func)
+                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, daemon=True, start=False):
+        super().__init__(process_func, daemon=daemon)
         self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
         self.uid = uid or uuid.uuid4()
         self.prefetch_batches = prefetch_batches
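
The switch to a daemonic fork-based pool means task pools can no longer outlive the Runtime that owns them: daemonic children are terminated automatically when the parent process exits. A minimal sketch of that behavior (assumes a POSIX system, since fork is used; `worker` is a hypothetical stand-in for a pool loop):

import multiprocessing as mp
import time

def worker():
    # a non-daemonic child stuck in a loop like this would keep
    # the parent hanging at exit until joined or killed by hand
    while True:
        time.sleep(1)

if __name__ == '__main__':
    pool = mp.get_context('fork').Process(target=worker, daemon=True)
    pool.start()
    # the parent exits here; the daemonic child is terminated with it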

+ 2 - 1
hivemind/server/__init__.py

@@ -98,7 +98,8 @@ class Server(threading.Thread):
         sock.listen()
         sock.settimeout(self.update_period)
 
-        processes = [mp.Process(target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts))
+        processes = [mp.context.ForkProcess(
+            target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts), daemon=True)
                      for i in range(num_handlers)]
         return processes
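
Connection handlers stay fork-based (`mp.context.ForkProcess`) because each handler must inherit the already-bound listening socket from the parent; `daemon=True` ensures they die with the server instead of lingering. A rough sketch of this pattern, assuming POSIX fork and a toy echo loop in place of `socket_loop`:

import multiprocessing as mp
import socket

def handler_loop(sock: socket.socket):
    # every forked handler accepts on the same inherited listening socket
    while True:
        conn, _ = sock.accept()
        with conn:
            conn.sendall(conn.recv(1024))  # toy echo, stand-in for socket_loop

if __name__ == '__main__':
    sock = socket.socket()
    sock.bind(('127.0.0.1', 0))
    sock.listen()
    handlers = [mp.get_context('fork').Process(target=handler_loop, args=(sock,), daemon=True)
                for i in range(4)]
    for handler in handlers:
        handler.start()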
 

+ 10 - 13
tests/test_moe.py

@@ -51,7 +51,7 @@ def test_determinism():
     mask = torch.randint(0, 1, (32, 1024))
 
     with background_server(num_experts=1, device='cpu', expert_cls='det_dropout',
-                           no_optimizer=True, no_dht=True) as (localhost, server_port, dht_port):
+                           no_optimizer=True, no_dht=True) as (interface, server_port, dht_port):
         expert = hivemind.RemoteExpert(uid=f'expert.0', port=server_port)
 
         out = expert(xx, mask)
@@ -68,27 +68,24 @@ def test_compute_expert_scores():
     try:
         dht = hivemind.DHTNode(port=hivemind.find_open_port(), start=True)
         moe = hivemind.client.moe.RemoteMixtureOfExperts(
-            dht=dht, in_features=1024, grid_size=[40], k_best=4, k_min=1, timeout_after_k_min=1,
+            dht=dht, in_features=1024, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
             uid_prefix='expert')
         gx, gy = torch.randn(4, 5, requires_grad=True), torch.torch.randn(4, 3, requires_grad=True)
         ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
-            [hivemind.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}') for e in range(len(ii[b]))]
-            for b in range(len(ii))
+            [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}')
+             for expert_i in range(len(ii[batch_i]))]
+            for batch_i in range(len(ii))
         ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
         logits = moe.compute_expert_scores([gx, gy], batch_experts)
         torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
         assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
 
-        for b in range(len(ii)):
-            for e in range(len(ii[b])):
-                assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
+        for batch_i in range(len(ii)):
+            for expert_i in range(len(ii[batch_i])):
+                assert torch.allclose(logits[batch_i, expert_i],
+                                      gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]), \
+                    "compute_expert_scores returned incorrect score"
     finally:
         dht.shutdown()
-
-
-if __name__ == '__main__':
-    test_remote_module_call()
-    test_compute_expert_scores()
-    test_determinism()
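
The renamed loop variables make the invariant under test easier to read: the logit of expert `expert.{i}.{j}` in batch row `b` should equal `gx[b, i] + gy[b, j]`, the sum of that row's per-grid-dimension scores. A standalone illustration of the same arithmetic in plain torch (names here are illustrative, not hivemind API):

import torch

gx, gy = torch.randn(4, 5), torch.randn(4, 3)  # per-dimension grid scores
b, i, j = 2, 4, 1
expected_logit = gx[b, i] + gy[b, j]  # what compute_expert_scores should return for 'expert.4.1'
print(float(expected_logit))

The removed `__main__` block becomes redundant once the suite is collected and run by a test runner rather than invoked as a script.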

+ 78 - 43
tests/test_utils/run_server.py

@@ -8,11 +8,32 @@ import hivemind
 from .layers import name_to_block, name_to_input
 
 
-def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
-                      expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
-                      no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True, start=False,
-                      UID_DELIMETER=hivemind.DHTNode.UID_DELIMETER, **kwargs) -> hivemind.Server:
-    """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
+def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024,
+                      num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
+                      no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
+                      UID_DELIMETER=hivemind.DHTNode.UID_DELIMETER, start=False, **kwargs) -> hivemind.Server:
+    """
+    Instantiate a server with several identical experts. See argparse comments below for details
+    :param interface: 'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6
+    :param port: main server will listen to this port, default = find open port
+    :param num_experts: run this many identical experts
+    :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
+    :param hidden_dim: main dimension for expert_cls
+    :param num_handlers: server will use this many parallel processes to handle incoming requests
+    :param expert_prefix: all expert uids will be {expert_prefix}.{index}
+    :param expert_offset: expert uid will use indices in range(expert_offset, expert_offset + num_experts)
+    :param max_batch_size: total num examples in the same batch will not exceed this value
+    :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+    :param no_optimizer: if specified, all optimizers use learning rate=0
+    :param no_dht: if specified, the server will not be attached to a dht
+    :param initial_peers: a list of peers that will introduce this node to the dht,
+      e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]'), default = no peers
+    :param dht_port:  DHT node will listen on this port, default = find open port
+    :param root_port: if this server does not have initial_peers, it will create a virtual dht node on this port.
+        You can then use this node as initial peer for subsequent servers.
+    :param verbose: whether to print server started / finished / terminated events
+    :param start: if True, starts server right away and returns when server is ready for requests
+    """
     if verbose and len(kwargs) != 0:
         print("Ignored kwargs:", kwargs)
     assert expert_cls in name_to_block
@@ -57,7 +78,7 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
                                                      )
     # actually start server
     server = hivemind.Server(
-        dht, experts, addr=host, port=port or hivemind.find_open_port(),
+        dht, experts, addr=interface, port=port or hivemind.find_open_port(),
         conn_handler_processes=num_handlers, device=device)
 
     if start:
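
For reference, a hypothetical direct use of the API documented above (argument names taken from the new docstring; a sketch, not an official entry point):

from tests.test_utils.run_server import make_dummy_server

server = make_dummy_server(interface='127.0.0.1', num_experts=2, expert_cls='ffn',
                           no_dht=True, start=True)  # start=True blocks until ready
try:
    print(f'serving at {server.addr}:{server.port}')
finally:
    server.shutdown()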
@@ -69,55 +90,69 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
 
 
 @contextmanager
-def background_server(*args, verbose=True, **kwargs):
-    """ Runs server in a background process and returns a reference to it. """
-    recv_addr, send_addr = mp.Pipe(duplex=True)
-    trigger_shutdown = mp.Event()
+def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs):
+    """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
+    pipe, runners_pipe = mp.Pipe(duplex=True)
+    runner = mp.get_context("spawn").Process(
+        target=_server_runner, args=(runners_pipe, *args), kwargs=dict(verbose=verbose, **kwargs))
 
-    def server_runner():
+    try:
+        runner.start()
+        yield pipe.recv()  # once the server is ready, runner will send us a tuple(hostname, port, dht port)
+        pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
+    finally:
         try:
-            server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
-            dht_port = server.dht.port if server.dht is not None else None
-            send_addr.send((server.addr, server.port, dht_port))
-            trigger_shutdown.wait()
+            runner.join(timeout=shutdown_timeout)
         finally:
             if verbose:
-                print("Shutting down server...")
-            trigger_shutdown.set()  # if server failed internally, set the shutdown trigger anyway
-            server.shutdown()
+                print("Server failed to shutdown gracefully, terminating it the hard way...")
+            runner.terminate()
             if verbose:
-                print("Server shut down successfully.")
+                print("Server terminated.")
 
-    try:
-        runner = mp.Process(target=server_runner)
-        runner.start()
-        yield recv_addr.recv()  # yield tuple(hostname, port)
 
+def _server_runner(pipe, *args, verbose, **kwargs):
+    server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
+    try:
+        dht_port = server.dht.port if server.dht is not None else None
+        pipe.send((server.addr, server.port, dht_port))
+        pipe.recv()  # wait for shutdown signal
     finally:
-        trigger_shutdown.set()
-        runner.terminate()
-        runner.join()
+        if verbose:
+            print("Shutting down server...")
+        server.shutdown()
+        if verbose:
+            print("Server shut down successfully.")
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--host', type=str, default='0.0.0.0', required=False)
-    parser.add_argument('--port', type=int, default=None, required=False)
-    parser.add_argument('--num_experts', type=int, default=1, required=False)
-    parser.add_argument('--expert_cls', type=str, default='ffn', required=False)
-    parser.add_argument('--hidden_dim', type=int, default=1024, required=False)
-    parser.add_argument('--num_handlers', type=int, default=None, required=False)
-    parser.add_argument('--expert_prefix', type=str, default='expert', required=False)
-    parser.add_argument('--expert_offset', type=int, default=0, required=False)
-    parser.add_argument('--max_batch_size', type=int, default=16384, required=False)
-    parser.add_argument('--device', type=str, default=None, required=False)
-    parser.add_argument('--no_optimizer', action='store_true')
-    parser.add_argument('--no_dht', action='store_true')
-    parser.add_argument('--initial_peers', type=str, default="[]", required=False)
-    parser.add_argument('--dht_port', type=int, default=None, required=False)
-    parser.add_argument('--root_port', type=int, default=None, required=False)
-
-    parser.add_argument('--increase_file_limit', action='store_true')
+    parser.add_argument('--interface', type=str, default='0.0.0.0', required=False,
+                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6")
+    parser.add_argument('--port', type=int, default=None, required=False, help="server will listen to this port")
+    parser.add_argument('--num_experts', type=int, default=1, required=False, help="run this many identical experts")
+    parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
+                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
+    parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
+    parser.add_argument('--num_handlers', type=int, default=None, required=False,
+                        help='server will use this many processes to handle incoming requests')
+    parser.add_argument('--expert_prefix', type=str, default='expert', required=False,
+                        help='all expert uids will be {expert_prefix}.{index}')
+    parser.add_argument('--expert_offset', type=int, default=0, required=False,
+                        help='expert uid will use indices in range(expert_offset, expert_offset + num_experts)')
+    parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
+                        help='total num examples in the same batch will not exceed this value')
+    parser.add_argument('--device', type=str, default=None, required=False,
+                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
+    parser.add_argument('--no_optimizer', action='store_true', help='if specified, all optimizers use learning rate=0')
+    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
+    parser.add_argument('--initial_peers', type=str, default="[]", required=False, help='a list of peers that will'
+                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
+    parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
+    parser.add_argument('--root_port', type=int, default=None, required=False, help='If this server does not have peers'
+                        ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
+    parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
+                        ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
 
     args = vars(parser.parse_args())
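
Taken together, the runner changes replace the old in-closure `server_runner` plus a separate shutdown `Event` with a single duplex pipe and a spawn-started process; closures cannot be pickled under the `spawn` start method, which is why the runner moved to the module-level `_server_runner`. A condensed sketch of the resulting control flow, with hypothetical address payloads:

import multiprocessing as mp

def _runner(pipe):
    pipe.send(('127.0.0.1', 8080))  # child: announce readiness and address
    pipe.recv()                     # child: block until the shutdown signal
    # graceful cleanup (server.shutdown()) would go here

if __name__ == '__main__':
    pipe, runners_pipe = mp.Pipe(duplex=True)
    runner = mp.get_context('spawn').Process(target=_runner, args=(runners_pipe,))
    runner.start()
    print(pipe.recv())       # parent: wait until the child reports ready
    pipe.send('SHUTDOWN')    # parent: request graceful shutdown on exit
    runner.join(timeout=5)   # grace period, mirroring shutdown_timeout...
    if runner.is_alive():
        runner.terminate()   # ...then terminate it the hard way

From the caller's side, usage stays a one-liner: `with background_server(num_experts=1, no_dht=True) as (interface, port, dht_port): ...`.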