run_server.py

import argparse
import ast
import multiprocessing as mp
import resource
from contextlib import contextmanager

import torch

import tesseract
from .layers import name_to_block
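
# This script hosts `num_experts` identical dummy experts (blocks from `name_to_block`,
# e.g. 'ffn') in a TesseractServer, optionally joined into a TesseractNetwork DHT.
# Note: because of the relative import above, the script must be run as part of its
# package (e.g. with `python -m`), not as a standalone file.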


def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
                      expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
                      no_network=False, initial_peers=(), network_port=None, root_port=None, verbose=True, start=True,
                      **kwargs) -> tesseract.TesseractServer:
    """ Creates a TesseractServer with dummy experts; if start=True, launches it in the background and awaits readiness """
    if verbose and len(kwargs) != 0:
        print("Ignored kwargs:", kwargs)
    assert expert_cls in name_to_block, f"Unknown expert class: {expert_cls}"
    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize network: either start a new DHT root or bootstrap from the given peers
    network = None
    if not no_network:
        if not len(initial_peers):
            print("No initial peers provided. Starting additional network as an initial peer.")
            network = tesseract.TesseractNetwork(
                *initial_peers, port=root_port or tesseract.find_open_port(), start=True)
            print(f"Running DHT root on port {network.port}")
        else:
            print("Bootstrapping dht with peers:", initial_peers)
            if root_port is not None:
                print(f"Warning: root_port={root_port} will not be used since we already have some peers.")
            network = tesseract.TesseractNetwork(
                *initial_peers, port=network_port or tesseract.find_open_port(), start=True)
            if verbose:
                print(f"Running network node on port {network.port}")

    # initialize experts: each expert is a jit-scripted block with its own optimizer and i/o schema
    experts = {}
    for i in range(num_experts):
        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
        # with no_optimizer, SGD with lr=0 makes parameter updates a no-op
        opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
        expert_uid = f'{expert_prefix}{i + expert_offset}'
        experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=(tesseract.BatchTensorProto(hidden_dim),),
                                                      outputs_schema=tesseract.BatchTensorProto(hidden_dim),
                                                      max_batch_size=max_batch_size,
                                                      )

    # actually start server
    server = tesseract.TesseractServer(
        network, experts, addr=host, port=port or tesseract.find_open_port(),
        conn_handler_processes=num_handlers, device=device)

    if start:
        server.run_in_background(await_ready=True)
        if verbose:
            print(f"Server started at {server.addr}:{server.port}")
            print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")
    return server
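

# A minimal usage sketch (not part of the original file): create a server with two experts
# and block until it terminates. All names used here are defined above.
#
#     server = make_dummy_server(num_experts=2, expert_cls='ffn', hidden_dim=512, start=True)
#     try:
#         server.join()
#     finally:
#         server.shutdown()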


@contextmanager
def background_server(*args, verbose=True, **kwargs):
    """ A context manager that runs the server in a background process and yields its (hostname, port) """
    recv_addr, send_addr = mp.Pipe(duplex=True)
    trigger_shutdown = mp.Event()

    # server_runner is a closure, so this relies on fork-based multiprocessing (the default on Linux)
    def server_runner():
        server = None  # guard against make_dummy_server failing before assignment
        try:
            server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
            send_addr.send((server.addr, server.port))
            trigger_shutdown.wait()
        finally:
            if verbose:
                print("Shutting down server...")
            trigger_shutdown.set()  # if server failed internally, set the shutdown trigger anyway
            if server is not None:
                server.shutdown()
            if verbose:
                print("Server shut down successfully.")

    try:
        runner = mp.Process(target=server_runner)
        runner.start()
        yield recv_addr.recv()  # yield (hostname, port) once the server reports ready
    finally:
        trigger_shutdown.set()
        runner.join()
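

# Typical use (a sketch, not from the original file): run a throwaway server around a test body;
# the context manager yields the (hostname, port) pair sent over the pipe above.
#
#     with background_server(num_experts=1, hidden_dim=1024) as (host, port):
#         print(f"Server is listening at {host}:{port}")  # connect a client here
#     # on exit, the shutdown trigger is set and the server process is joined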


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='0.0.0.0', required=False)
    parser.add_argument('--port', type=int, default=None, required=False)
    parser.add_argument('--num_experts', type=int, default=1, required=False)
    parser.add_argument('--expert_cls', type=str, default='ffn', required=False)
    parser.add_argument('--hidden_dim', type=int, default=1024, required=False)
    parser.add_argument('--num_handlers', type=int, default=None, required=False)
    parser.add_argument('--expert_prefix', type=str, default='expert.', required=False)
    parser.add_argument('--expert_offset', type=int, default=0, required=False)
    parser.add_argument('--max_batch_size', type=int, default=16384, required=False)
    parser.add_argument('--device', type=str, default=None, required=False)
    parser.add_argument('--no_optimizer', action='store_true')
    parser.add_argument('--no_network', action='store_true')
    parser.add_argument('--initial_peers', type=str, default="[]", required=False)
    parser.add_argument('--network_port', type=int, default=None, required=False)
    parser.add_argument('--root_port', type=int, default=None, required=False)
    parser.add_argument('--increase_file_limit', action='store_true')
    args = vars(parser.parse_args())

    if args.pop('increase_file_limit'):
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
        except Exception:
            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))

    # --initial_peers is a python literal, e.g. "[('127.0.0.1', 1337)]"; literal_eval is a safer eval
    args['initial_peers'] = ast.literal_eval(args['initial_peers'])

    server = None  # guard against make_dummy_server failing before assignment
    try:
        server = make_dummy_server(**args, start=True, verbose=True)
        server.join()
    finally:
        if server is not None:
            server.shutdown()
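
# Example invocation (a sketch; flags as defined above, the module path depends on the package layout):
#     python -m <package>.run_server --num_experts 4 --hidden_dim 512 --increase_file_limit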