|
@@ -8,11 +8,32 @@ import hivemind
|
|
from .layers import name_to_block, name_to_input
|
|
from .layers import name_to_block, name_to_input
|
|
|
|
|
|
|
|
|
|
-def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
|
|
|
|
- expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
|
|
|
|
- no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True, start=False,
|
|
|
|
- UID_DELIMETER=hivemind.DHTNode.UID_DELIMETER, **kwargs) -> hivemind.Server:
|
|
|
|
- """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
|
|
|
|
|
|
+def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024,
|
|
|
|
+ num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
|
|
|
|
+ no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
|
|
|
|
+ UID_DELIMETER=hivemind.DHTNode.UID_DELIMETER, start=False, **kwargs) -> hivemind.Server:
|
|
|
|
+ """
|
|
|
|
+ Instantiate a server with several identical experts. See argparse comments below for details
|
|
|
|
+ :param interface: 'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6
|
|
|
|
+ :param port: main server will listen to this port, default = find open port
|
|
|
|
+ :param num_experts: run this many identical experts
|
|
|
|
+ :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
|
|
|
|
+ :param hidden_dim: main dimension for expert_cls
|
|
|
|
+ :param num_handlers: server will use this many parallel processes to handle incoming requests
|
|
|
|
+ :param expert_prefix: all expert uids will be {expert_prefix}.{index}
|
|
|
|
+ :param expert_offset: expert uid will use indices in range(expert_offset, expert_offset + num_experts)
|
|
|
|
+ :param max_batch_size: total num examples in the same batch will not exceed this value
|
|
|
|
+ :param device: all experts will use this device in torch notation; default: cuda if available else cpu
|
|
|
|
+ :param no_optimizer: if specified, all optimizers use learning rate=0
|
|
|
|
+ :param no_dht: if specified, the server will not be attached to a dht
|
|
|
|
+ :param initial_peers: a list of peers that will introduce this node to the dht,
|
|
|
|
+ e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]'), default = no peers
|
|
|
|
+ :param dht_port: DHT node will listen on this port, default = find open port
|
|
|
|
+ :param root_port: if this server does not have initial_peers, it will create a virtual dht node on this port.
|
|
|
|
+ You can then use this node as initial peer for subsequent servers.
|
|
|
|
+ :param verbose: whether to print server started / finished / terminated events
|
|
|
|
+ :param start: if True, starts server right away and returns when server is ready for requests
|
|
|
|
+ """
|
|
if verbose and len(kwargs) != 0:
|
|
if verbose and len(kwargs) != 0:
|
|
print("Ignored kwargs:", kwargs)
|
|
print("Ignored kwargs:", kwargs)
|
|
assert expert_cls in name_to_block
|
|
assert expert_cls in name_to_block
|
|
@@ -57,7 +78,7 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
|
|
)
|
|
)
|
|
# actually start server
|
|
# actually start server
|
|
server = hivemind.Server(
|
|
server = hivemind.Server(
|
|
- dht, experts, addr=host, port=port or hivemind.find_open_port(),
|
|
|
|
|
|
+ dht, experts, addr=interface, port=port or hivemind.find_open_port(),
|
|
conn_handler_processes=num_handlers, device=device)
|
|
conn_handler_processes=num_handlers, device=device)
|
|
|
|
|
|
if start:
|
|
if start:
|
|
@@ -69,55 +90,69 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
@contextmanager
|
|
-def background_server(*args, verbose=True, **kwargs):
|
|
|
|
- """ Runs server in a background process and returns a reference to it. """
|
|
|
|
- recv_addr, send_addr = mp.Pipe(duplex=True)
|
|
|
|
- trigger_shutdown = mp.Event()
|
|
|
|
|
|
+def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs):
|
|
|
|
+ """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
|
|
|
|
+ pipe, runners_pipe = mp.Pipe(duplex=True)
|
|
|
|
+ runner = mp.get_context("spawn").Process(
|
|
|
|
+ target=_server_runner, args=(runners_pipe, *args), kwargs=dict(verbose=verbose, **kwargs))
|
|
|
|
|
|
- def server_runner():
|
|
|
|
|
|
+ try:
|
|
|
|
+ runner.start()
|
|
|
|
+ yield pipe.recv() # once the server is ready, runner will send us a tuple(hostname, port, dht port)
|
|
|
|
+ pipe.send('SHUTDOWN') # on exit from context, send shutdown signal
|
|
|
|
+ finally:
|
|
try:
|
|
try:
|
|
- server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
|
|
|
|
- dht_port = server.dht.port if server.dht is not None else None
|
|
|
|
- send_addr.send((server.addr, server.port, dht_port))
|
|
|
|
- trigger_shutdown.wait()
|
|
|
|
|
|
+ runner.join(timeout=shutdown_timeout)
|
|
finally:
|
|
finally:
|
|
if verbose:
|
|
if verbose:
|
|
- print("Shutting down server...")
|
|
|
|
- trigger_shutdown.set() # if server failed internally, set the shutdown trigger anyway
|
|
|
|
- server.shutdown()
|
|
|
|
|
|
+ print("Server failed to shutdown gracefully, terminating it the hard way...")
|
|
|
|
+ runner.terminate()
|
|
if verbose:
|
|
if verbose:
|
|
- print("Server shut down successfully.")
|
|
|
|
|
|
+ print("Server terminated.")
|
|
|
|
|
|
- try:
|
|
|
|
- runner = mp.Process(target=server_runner)
|
|
|
|
- runner.start()
|
|
|
|
- yield recv_addr.recv() # yield tuple(hostname, port)
|
|
|
|
|
|
|
|
|
|
+def _server_runner(pipe, *args, verbose, **kwargs):
|
|
|
|
+ server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
|
|
|
|
+ try:
|
|
|
|
+ dht_port = server.dht.port if server.dht is not None else None
|
|
|
|
+ pipe.send((server.addr, server.port, dht_port))
|
|
|
|
+ pipe.recv() # wait for shutdown signal
|
|
finally:
|
|
finally:
|
|
- trigger_shutdown.set()
|
|
|
|
- runner.terminate()
|
|
|
|
- runner.join()
|
|
|
|
|
|
+ if verbose:
|
|
|
|
+ print("Shutting down server...")
|
|
|
|
+ server.shutdown()
|
|
|
|
+ if verbose:
|
|
|
|
+ print("Server shut down successfully.")
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser()
|
|
parser = argparse.ArgumentParser()
|
|
- parser.add_argument('--host', type=str, default='0.0.0.0', required=False)
|
|
|
|
- parser.add_argument('--port', type=int, default=None, required=False)
|
|
|
|
- parser.add_argument('--num_experts', type=int, default=1, required=False)
|
|
|
|
- parser.add_argument('--expert_cls', type=str, default='ffn', required=False)
|
|
|
|
- parser.add_argument('--hidden_dim', type=int, default=1024, required=False)
|
|
|
|
- parser.add_argument('--num_handlers', type=int, default=None, required=False)
|
|
|
|
- parser.add_argument('--expert_prefix', type=str, default='expert', required=False)
|
|
|
|
- parser.add_argument('--expert_offset', type=int, default=0, required=False)
|
|
|
|
- parser.add_argument('--max_batch_size', type=int, default=16384, required=False)
|
|
|
|
- parser.add_argument('--device', type=str, default=None, required=False)
|
|
|
|
- parser.add_argument('--no_optimizer', action='store_true')
|
|
|
|
- parser.add_argument('--no_dht', action='store_true')
|
|
|
|
- parser.add_argument('--initial_peers', type=str, default="[]", required=False)
|
|
|
|
- parser.add_argument('--dht_port', type=int, default=None, required=False)
|
|
|
|
- parser.add_argument('--root_port', type=int, default=None, required=False)
|
|
|
|
-
|
|
|
|
- parser.add_argument('--increase_file_limit', action='store_true')
|
|
|
|
|
|
+ parser.add_argument('--interface', type=str, default='0.0.0.0', required=False,
|
|
|
|
+ help="'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6")
|
|
|
|
+ parser.add_argument('--port', type=int, default=None, required=False, help="server will listen to this port")
|
|
|
|
+ parser.add_argument('--num_experts', type=int, default=1, required=False, help="run this many identical experts")
|
|
|
|
+ parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
|
|
|
|
+ help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
|
|
|
|
+ parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
|
|
|
|
+ parser.add_argument('--num_handlers', type=int, default=None, required=False,
|
|
|
|
+ help='server will use this many processes to handle incoming requests')
|
|
|
|
+ parser.add_argument('--expert_prefix', type=str, default='expert', required=False,
|
|
|
|
+ help='all expert uids will be {expert_prefix}.{index}')
|
|
|
|
+ parser.add_argument('--expert_offset', type=int, default=0, required=False,
|
|
|
|
+ help='expert uid will use indices in range(expert_offset, expert_offset + num_experts)')
|
|
|
|
+ parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
|
|
|
|
+ help='total num examples in the same batch will not exceed this value')
|
|
|
|
+ parser.add_argument('--device', type=str, default=None, required=False,
|
|
|
|
+ help='all experts will use this device in torch notation; default: cuda if available else cpu')
|
|
|
|
+ parser.add_argument('--no_optimizer', action='store_true', help='if specified, all optimizers use learning rate=0')
|
|
|
|
+ parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
|
|
|
|
+ parser.add_argument('--initial_peers', type=str, default="[]", required=False, help='a list of peers that will'
|
|
|
|
+ ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
|
|
|
|
+ parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
|
|
|
|
+ parser.add_argument('--root_port', type=int, default=None, required=False, help='If this server does not have peers'
|
|
|
|
+ ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
|
|
|
|
+ parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
|
|
|
|
+ ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
|
|
|
|
|
|
args = vars(parser.parse_args())
|
|
args = vars(parser.parse_args())
|
|
|
|
|