import argparse
import multiprocessing as mp
import resource
from contextlib import contextmanager
from typing import Tuple

import torch

import hivemind
from test_utils.layers import name_to_block, name_to_input

logger = hivemind.get_logger(__name__)

def make_dummy_server(listen_on='0.0.0.0:*', num_experts=None, expert_uids=None, expert_cls='ffn', hidden_dim=1024,
                      num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                      no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
                      start=False, **kwargs) -> hivemind.Server:
    """
    Instantiate a server with several identical experts. See argparse comments below for details.
    :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
    :param num_experts: run this many identical experts
    :param expert_prefix: all expert uids will be {expert_prefix}.{index}
    :param expert_offset: expert uids will use indices in range(expert_offset, expert_offset + num_experts)
    :param expert_uids: spawn experts with these exact uids; overrides num_experts, expert_prefix and expert_offset
    :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'
    :param hidden_dim: main dimension for expert_cls
    :param num_handlers: server will use this many parallel processes to handle incoming requests
    :param max_batch_size: total number of examples in the same batch will not exceed this value
    :param device: all experts will use this device in torch notation; default: cuda if available else cpu
    :param no_optimizer: if specified, all optimizers use learning rate=0
    :param no_dht: if specified, the server will not be attached to a dht
    :param initial_peers: a list of peers that will introduce this node to the dht,
        e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]; default = no peers
    :param dht_port: DHT node will listen on this port; default = find an open port
    :param root_port: if this server does not have initial_peers, it will create a virtual dht node on this port.
        You can then use this node as an initial peer for subsequent servers.
    :param verbose: whether to print server started / finished / terminated events
    :param start: if True, starts the server right away and returns once the server is ready for requests
    """
    assert (expert_uids is None) != (num_experts is None and expert_prefix == 'expert' and expert_offset == 0), \
        "Please provide either expert uids *or* (num_experts, expert_prefix and expert_offset), not both"
    if verbose and len(kwargs) != 0:
        print("Ignored kwargs:", kwargs)
    assert expert_cls in name_to_block
    # default: 8 handler processes per expert, regardless of whether num_experts or explicit expert_uids were given
    num_handlers = num_handlers if num_handlers is not None else (num_experts or len(expert_uids or ()) or 1) * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize dht
    dht = None
    if not no_dht:
        if not len(initial_peers):
            logger.info("No initial peers provided. Starting additional dht as an initial peer.")
            dht_root = hivemind.DHT(initial_peers=initial_peers, start=True,
                                    listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}")
            logger.info(f"Initializing DHT with port {dht_root.port}")
            initial_peers = [f"{hivemind.LOCALHOST}:{dht_root.port}"]
        else:
            logger.info(f"Bootstrapping dht with peers: {initial_peers}")
            if root_port is not None:
                logger.info(f"Warning: root_port={root_port} will not be used since we already have peers.")

        dht = hivemind.DHT(initial_peers=initial_peers, start=True,
                           listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}")
        if verbose:
            logger.info(f"Running dht node on port {dht.port}")
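
    # infer the experts' input schema from a sample batch of size 4 built by test_utils.layers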
    sample_input = name_to_input[expert_cls](4, hidden_dim)
    if isinstance(sample_input, tuple):
        args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg) for arg in sample_input)
    else:
        args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input),)

    # initialize experts
    if expert_uids is None:
        num_experts = num_experts if num_experts is not None else 1
        expert_uids = [f'{expert_prefix}{hivemind.DHT.UID_DELIMITER}{i + expert_offset}'
                       for i in range(num_experts)]

    experts = {}
    for expert_uid in expert_uids:
        expert = name_to_block[expert_cls](hidden_dim)
        opt = torch.optim.SGD(expert.parameters(), 0.0 if no_optimizer else 0.05)
        experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                     args_schema=args_schema,
                                                     outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
                                                     max_batch_size=max_batch_size)

    # actually start server
    server = hivemind.Server(
        dht, experts, listen_on=listen_on,
        num_connection_handlers=num_handlers, device=device)

    if start:
        server.run_in_background(await_ready=True)
        if verbose:
            logger.info(f"Server started at {server.listen_on}")
            logger.info(f"Got {len(experts)} active experts of type {expert_cls}: {list(experts.keys())}")
    return server
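

# Illustrative usage sketch for make_dummy_server, kept as a comment so it does not run on import;
# the argument values below are arbitrary examples, not recommended settings:
#
#   server = make_dummy_server(num_experts=2, expert_cls='ffn', hidden_dim=1024, start=True)
#   try:
#       print(server.listen_on)  # clients can send requests to this endpoint while the server runs
#   finally:
#       server.shutdown()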


@contextmanager
def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
    """ A context manager that spawns a server in a background process, awaits readiness on entry and shuts it down on exit """
    pipe, runners_pipe = mp.Pipe(duplex=True)
    runner = mp.get_context("spawn").Process(
        target=_server_runner, args=(runners_pipe, *args), kwargs=dict(verbose=verbose, **kwargs))

    try:
        runner.start()
        yield pipe.recv()  # once the server is ready, the runner sends us a tuple (server endpoint, dht endpoint)
        pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
    finally:
        runner.join(timeout=shutdown_timeout)
        if runner.is_alive():
            if verbose:
                logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
            runner.kill()
            if verbose:
                logger.info("Server terminated.")
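

# Illustrative usage sketch for background_server (comment only, not executed); the yielded pair is
# (server endpoint, dht endpoint) as sent by _server_runner below, and the variable names are just examples:
#
#   with background_server(num_experts=1, device='cpu') as (server_endpoint, dht_endpoint):
#       ...  # the server keeps serving requests at server_endpoint until this block exits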


def _server_runner(pipe, *args, verbose, **kwargs):
    server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
    try:
        if server.dht is not None:
            dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
        else:
            dht_listen_on = None
        pipe.send((server.listen_on, dht_listen_on))
        pipe.recv()  # wait for shutdown signal
    finally:
        if verbose:
            logger.info("Shutting down server...")
        server.shutdown()
        server.join()
        if verbose:
            logger.info("Server shut down successfully.")


if __name__ == '__main__':
    # fmt:off
    parser = argparse.ArgumentParser()
    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
                        help="'localhost' for local connections only, '0.0.0.0' for ipv4, '::' for ipv6")
    parser.add_argument('--num_experts', type=int, default=1, required=False, help="run this many identical experts")
    parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
    parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
    parser.add_argument('--num_handlers', type=int, default=None, required=False,
                        help='server will use this many processes to handle incoming requests')
    parser.add_argument('--expert_prefix', type=str, default='expert', required=False,
                        help='all expert uids will be {expert_prefix}.{index}')
    parser.add_argument('--expert_offset', type=int, default=0, required=False,
                        help='expert uids will use indices in range(expert_offset, expert_offset + num_experts)')
    parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
                        help='total number of examples in the same batch will not exceed this value')
    parser.add_argument('--device', type=str, default=None, required=False,
                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
    parser.add_argument('--no_optimizer', action='store_true', help='if specified, all optimizers use learning rate=0')
    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
    parser.add_argument('--initial_peers', type=str, default="[]", required=False, help='a list of peers that will'
                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
    parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
    parser.add_argument('--root_port', type=int, default=None, required=False, help='If this server does not have peers'
                        ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
    parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
                        ' of processes a server can spawn before hitting "Too many open files"; use at your own risk.')
    # fmt:on

    args = vars(parser.parse_args())

    if args.pop('increase_file_limit'):
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
        except Exception:
            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))

    args['initial_peers'] = eval(args['initial_peers'])

    server = make_dummy_server(**args, start=True, verbose=True)
    try:
        server.join()
    finally:
        server.shutdown()
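
# Illustrative invocation (assumes this file is run directly and test_utils.layers is importable):
#   python run_server.py --num_experts 2 --expert_cls ffn --hidden_dim 1024 --listen_on 0.0.0.0:*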