# run_server.py

import resource
from contextlib import contextmanager
import multiprocessing as mp
import argparse

import torch
import hivemind

from .layers import name_to_block, name_to_input


def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024,
                      num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                      no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
                      UID_DELIMETER=hivemind.DHT.UID_DELIMETER, start=False, **kwargs) -> hivemind.Server:
  12. """
  13. Instantiate a server with several identical experts. See argparse comments below for details
  14. :param interface: 'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6
  15. :param port: main server will listen to this port, default = find open port
  16. :param num_experts: run this many identical experts
  17. :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
  18. :param hidden_dim: main dimension for expert_cls
  19. :param num_handlers: server will use this many parallel processes to handle incoming requests
  20. :param expert_prefix: all expert uids will be {expert_prefix}.{index}
  21. :param expert_offset: expert uid will use indices in range(expert_offset, expert_offset + num_experts)
  22. :param max_batch_size: total num examples in the same batch will not exceed this value
  23. :param device: all experts will use this device in torch notation; default: cuda if available else cpu
  24. :param no_optimizer: if specified, all optimizers use learning rate=0
  25. :param no_dht: if specified, the server will not be attached to a dht
  26. :param initial_peers: a list of peers that will introduce this node to the dht,
  27. e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]'), default = no peers
  28. :param dht_port: DHT node will listen on this port, default = find open port
  29. :param root_port: if this server does not have initial_peers, it will create a virtual dht node on this port.
  30. You can then use this node as initial peer for subsequent servers.
  31. :param verbose: whether to print server started / finished / terminated events
  32. :param start: if True, starts server right away and returns when server is ready for requests
  33. """
    if verbose and len(kwargs) != 0:
        print("Ignored kwargs:", kwargs)
    assert expert_cls in name_to_block
    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize dht
    dht = None
    if not no_dht:
        if not len(initial_peers):
            print("No initial peers provided. Starting additional dht as an initial peer.")
            dht_root = hivemind.DHT(
                *initial_peers, listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}",
                start=True)
            print(f"Initializing DHT with port {dht_root.port}")
            initial_peers = [f"{hivemind.LOCALHOST}:{dht_root.port}"]
        else:
            print("Bootstrapping dht with peers:", initial_peers)
            if root_port is not None:
                print(f"Warning: root_port={root_port} will not be used since we already have peers.")

        dht = hivemind.DHT(
            *initial_peers, listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}", start=True)
        if verbose:
            print(f"Running dht node on port {dht.port}")

    sample_input = name_to_input[expert_cls](4, hidden_dim)
    if isinstance(sample_input, tuple):
        args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg) for arg in sample_input)
    else:
        args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input),)

    # initialize experts
    experts = {}
    for i in range(num_experts):
        expert = name_to_block[expert_cls](hidden_dim)
        opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
        expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
        experts[expert_uid] = hivemind.ExpertBackend(
            name=expert_uid, expert=expert, opt=opt, args_schema=args_schema,
            outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim), max_batch_size=max_batch_size)

    # actually start server
    server = hivemind.Server(
        dht, experts, addr=interface, port=port or hivemind.find_open_port(),
        conn_handler_processes=num_handlers, device=device)

    if start:
        server.run_in_background(await_ready=True)
        if verbose:
            print(f"Server started at {server.addr}:{server.port}")
            print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")
    return server

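# Usage sketch (not part of the original module): one way to start a standalone server directly
# from Python with the helper above. The argument values below are purely illustrative; all names
# used (make_dummy_server, server.addr, server.port, server.shutdown) appear elsewhere in this file.
#
#     server = make_dummy_server(num_experts=2, expert_cls='ffn', hidden_dim=1024,
#                                no_dht=True, start=True, verbose=True)
#     try:
#         pass  # e.g. send requests to server.addr:server.port from a client process
#     finally:
#         server.shutdown()
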
@contextmanager
def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs):
    """ A context manager that creates a server in a background process, awaits .ready on entry and shuts down on exit """
    pipe, runners_pipe = mp.Pipe(duplex=True)
    runner = mp.get_context("spawn").Process(
        target=_server_runner, args=(runners_pipe, *args), kwargs=dict(verbose=verbose, **kwargs))

    try:
        runner.start()
        yield pipe.recv()  # once the server is ready, runner will send us a tuple (hostname, port, dht port)
        pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
    finally:
        try:
            runner.join(timeout=shutdown_timeout)
        finally:
            if verbose:
                print("Server failed to shutdown gracefully, terminating it the hard way...")
            runner.terminate()
            if verbose:
                print("Server terminated.")

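# Usage sketch (illustrative, not from the original file): the context manager above yields the
# tuple (hostname, port, dht_port) that _server_runner sends once the server is ready. The keyword
# arguments shown are forwarded to make_dummy_server; their values here are arbitrary examples.
#
#     with background_server(num_experts=1, expert_cls='ffn', no_dht=True) as (host, port, dht_port):
#         pass  # the server is up; connect to it at host:port
#     # leaving the block sends 'SHUTDOWN', joins the worker process, then terminates it
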
def _server_runner(pipe, *args, verbose, **kwargs):
    server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
    try:
        dht_port = server.dht.port if server.dht is not None else None
        pipe.send((server.addr, server.port, dht_port))
        pipe.recv()  # wait for shutdown signal
    finally:
        if verbose:
            print("Shutting down server...")
        server.shutdown()
        if verbose:
            print("Server shut down successfully.")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--interface', type=str, default='0.0.0.0', required=False,
                        help="'localhost' for local connections only, '0.0.0.0' for ipv4, '::' for ipv6")
    parser.add_argument('--port', type=int, default=None, required=False, help="server will listen to this port")
    parser.add_argument('--num_experts', type=int, default=1, required=False, help="run this many identical experts")
    parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
    parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
    parser.add_argument('--num_handlers', type=int, default=None, required=False,
                        help='server will use this many processes to handle incoming requests')
    parser.add_argument('--expert_prefix', type=str, default='expert', required=False,
                        help='all expert uids will be {expert_prefix}.{index}')
    parser.add_argument('--expert_offset', type=int, default=0, required=False,
                        help='expert uids will use indices in range(expert_offset, expert_offset + num_experts)')
    parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
                        help='total number of examples in the same batch will not exceed this value')
    parser.add_argument('--device', type=str, default=None, required=False,
                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
    parser.add_argument('--no_optimizer', action='store_true', help='if specified, all optimizers use learning rate=0')
    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
    parser.add_argument('--initial_peers', type=str, default="[]", required=False, help='a list of peers that will'
                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
    parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
    parser.add_argument('--root_port', type=int, default=None, required=False, help='if this server does not have peers,'
                        ' it will create a virtual dht node on this port. You can then use this node as an initial peer.')
    parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
                        ' of processes a server can spawn before hitting "Too many open files"; use at your own risk.')

    args = vars(parser.parse_args())

    if args.pop('increase_file_limit'):
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
        except:
            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))

    args['initial_peers'] = eval(args['initial_peers'])

    server = None
    try:
        server = make_dummy_server(**args, start=True, verbose=True)
        server.join()
    finally:
        if server is not None:
            server.shutdown()
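# Example invocation (hypothetical module path and values; the flags are the ones defined by the
# parser above). Because this file imports .layers relatively, it is typically run as a module
# from whatever package it lives in, e.g.:
#     python -m tests.test_utils.run_server --num_experts 2 --expert_cls ffn --no_dht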