import argparse
import ast
import multiprocessing as mp
import resource
from contextlib import contextmanager
from typing import Tuple

import torch

import hivemind
from test_utils.layers import name_to_block, name_to_input


def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hidden_dim=1024,
                      num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                      no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
                      start=False, **kwargs) -> hivemind.Server:
    """
    Instantiate a server with several identical experts. See argparse comments below for details.

    :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
    :param num_experts: run this many identical experts
    :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'
    :param hidden_dim: main dimension for expert_cls
    :param num_handlers: server will use this many parallel processes to handle incoming requests
    :param expert_prefix: all expert uids will be {expert_prefix}.{index}
    :param expert_offset: expert uids will use indices in range(expert_offset, expert_offset + num_experts)
    :param max_batch_size: total number of examples in the same batch will not exceed this value
    :param device: all experts will use this device in torch notation; default: cuda if available else cpu
    :param no_optimizer: if specified, all optimizers use learning rate=0
    :param no_dht: if specified, the server will not be attached to a dht
    :param initial_peers: a list of peers that will introduce this node to the dht,
        e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]; default = no peers
    :param dht_port: DHT node will listen on this port; default = find an open port
    :param root_port: if this server does not have initial_peers, it will create a virtual dht node on this port.
        You can then use this node as an initial peer for subsequent servers.
    :param verbose: whether to print server started / finished / terminated events
    :param start: if True, starts the server right away and returns when the server is ready for requests
    """
    if verbose and len(kwargs) != 0:
        print("Ignored kwargs:", kwargs)
    assert expert_cls in name_to_block
    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize dht
    dht = None
    if not no_dht:
        if not len(initial_peers):
            print("No initial peers provided. Starting additional dht as an initial peer.")
            dht_root = hivemind.DHT(initial_peers=initial_peers, start=True,
                                    listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}")
            print(f"Initializing DHT with port {dht_root.port}")
            initial_peers = [f"{hivemind.LOCALHOST}:{dht_root.port}"]
        else:
            print("Bootstrapping dht with peers:", initial_peers)
            if root_port is not None:
                print(f"Warning: root_port={root_port} will not be used since we already have peers.")

        dht = hivemind.DHT(initial_peers=initial_peers, start=True,
                           listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}")
        if verbose:
            print(f"Running dht node on port {dht.port}")
    sample_input = name_to_input[expert_cls](4, hidden_dim)
    if isinstance(sample_input, tuple):
        args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg) for arg in sample_input)
    else:
        args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input),)

    # initialize experts
    experts = {}
    for i in range(num_experts):
        expert = name_to_block[expert_cls](hidden_dim)
        opt = torch.optim.SGD(expert.parameters(), 0.0 if no_optimizer else 0.05)
        expert_uid = f'{expert_prefix}{hivemind.DHT.UID_DELIMITER}{i + expert_offset}'
        experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                     args_schema=args_schema,
                                                     outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
                                                     max_batch_size=max_batch_size,
                                                     )
    # actually start server
    server = hivemind.Server(
        dht, experts, listen_on=listen_on,
        num_connection_handlers=num_handlers, device=device)

    if start:
        server.run_in_background(await_ready=True)
        if verbose:
            print(f"Server started at {server.listen_on}")
            print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")

    return server
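
# Usage sketch (illustrative only, not part of the module's control flow): spin up a small
# CPU-only server without a DHT and shut it down manually. The argument values below are
# assumptions chosen for the example, not defaults required by hivemind.
#
#   server = make_dummy_server(listen_on='127.0.0.1:*', num_experts=2, expert_cls='ffn',
#                              hidden_dim=64, device='cpu', no_dht=True, start=True)
#   print(server.listen_on)   # endpoint that clients can connect to
#   server.shutdown()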


@contextmanager
def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
    """ A context manager that creates a server in a background process, awaits .ready on entry and shuts down on exit """
    pipe, runners_pipe = mp.Pipe(duplex=True)
    runner = mp.get_context("spawn").Process(
        target=_server_runner, args=(runners_pipe, *args), kwargs=dict(verbose=verbose, **kwargs))

    try:
        runner.start()
        yield pipe.recv()  # once the server is ready, runner sends us a tuple (server endpoint, dht endpoint or None)
        pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
    finally:
        runner.join(timeout=shutdown_timeout)
        if runner.is_alive():
            if verbose:
                print("Server failed to shutdown gracefully, terminating it the hard way...")
            runner.terminate()
            if verbose:
                print("Server terminated.")
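
# Usage sketch (illustrative only): run a disposable server for the duration of a test. The
# context manager yields the server endpoint ("host:port") and the DHT endpoint (None when
# no_dht=True); the keyword values here are example assumptions.
#
#   with background_server(num_experts=1, device='cpu', no_dht=True) as (server_endpoint, dht_endpoint):
#       ...  # connect clients to server_endpoint while the block is active
#   # on exit, the background process receives 'SHUTDOWN' and the server is stopped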


def _server_runner(pipe, *args, verbose, **kwargs):
    """ Runs in a spawned process: starts a server, reports its endpoints, then waits for the shutdown signal """
    server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
    try:
        if server.dht is not None:
            dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
        else:
            dht_listen_on = None
        pipe.send((server.listen_on, dht_listen_on))
        pipe.recv()  # wait for shutdown signal
    finally:
        if verbose:
            print("Shutting down server...")
        server.shutdown()
        if verbose:
            print("Server shut down successfully.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--interface', type=str, default='0.0.0.0', required=False,
                        help="'localhost' for local connections only, '0.0.0.0' for ipv4, '::' for ipv6")
    parser.add_argument('--port', type=int, default=None, required=False, help="server will listen to this port")
    parser.add_argument('--num_experts', type=int, default=1, required=False, help="run this many identical experts")
    parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
    parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
    parser.add_argument('--num_handlers', type=int, default=None, required=False,
                        help='server will use this many processes to handle incoming requests')
    parser.add_argument('--expert_prefix', type=str, default='expert', required=False,
                        help='all expert uids will be {expert_prefix}.{index}')
    parser.add_argument('--expert_offset', type=int, default=0, required=False,
                        help='expert uids will use indices in range(expert_offset, expert_offset + num_experts)')
    parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
                        help='total number of examples in the same batch will not exceed this value')
    parser.add_argument('--device', type=str, default=None, required=False,
                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
    parser.add_argument('--no_optimizer', action='store_true', help='if specified, all optimizers use learning rate=0')
    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
    parser.add_argument('--initial_peers', type=str, default="[]", required=False, help='a list of peers that will'
                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
    parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
    parser.add_argument('--root_port', type=int, default=None, required=False, help='If this server does not have peers'
                        ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
    parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
                        ' of processes a server can spawn before hitting "Too many open files"; use at your own risk.')

    args = vars(parser.parse_args())

    if args.pop('increase_file_limit'):
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
        except Exception:
            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))

    # combine --interface and --port into the listen_on endpoint expected by make_dummy_server
    args['listen_on'] = f"{args.pop('interface')}:{args.pop('port') or '*'}"
    # parse the peer list safely instead of eval-ing arbitrary input
    args['initial_peers'] = ast.literal_eval(args['initial_peers'])

    server = make_dummy_server(**args, start=True, verbose=True)
    try:
        server.join()
    finally:
        server.shutdown()
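
# Example invocation (sketch; flag values are illustrative):
#   python run_server.py --num_experts 4 --expert_cls ffn --hidden_dim 512 --no_dht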