# hivemind/server/__init__.py
  1. import multiprocessing as mp
  2. import multiprocessing.synchronize
  3. import threading
  4. from typing import Dict, Optional
  5. from hivemind.dht import DHT
  6. from hivemind.server.runtime import Runtime
  7. from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
  8. from hivemind.server.expert_backend import ExpertBackend
  9. from hivemind.server.checkpoint_saver import CheckpointSaver
  10. from hivemind.server.connection_handler import ConnectionHandler
  11. from hivemind.server.dht_handler import DHTHandlerThread
  12. from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
  13. logger = get_logger(__name__)
  14. class Server(threading.Thread):
  15. """
  16. Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
  17. After creation, a server should be started: see Server.run or Server.run_in_background.
  18. A working server does 3 things:
  19. - processes incoming forward/backward requests via Runtime (created by the server)
  20. - publishes updates to expert status every :update_period: seconds
  21. - follows orders from HivemindController - if it exists
  22. :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
  23. but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
  24. :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
  25. :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
  26. :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
  27. if too small for normal functioning, we recommend 4 handlers per expert backend.
  28. :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
  29. if dht is None, this parameter is ignored.
  30. :param start: if True, the server will immediately start as a background thread and returns control after server
  31. is ready (see .ready below)
  32. """
  33. def __init__(
  34. self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], listen_on: Endpoint = "0.0.0.0:*",
  35. num_connection_handlers: int = 1, update_period: int = 30, start=False, checkpoint_dir=None, **kwargs):
  36. super().__init__()
  37. self.dht, self.experts, self.update_period = dht, expert_backends, update_period
  38. if get_port(listen_on) is None:
  39. listen_on = replace_port(listen_on, new_port=find_open_port())
  40. self.listen_on, self.port = listen_on, get_port(listen_on)
  41. self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
  42. if checkpoint_dir is not None:
  43. self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
  44. else:
  45. self.checkpoint_saver = None
  46. self.runtime = Runtime(self.experts, **kwargs)
  47. if start:
  48. self.run_in_background(await_ready=True)
  49. def run(self):
  50. """
  51. Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
  52. runs Runtime (self.runtime) to process incoming requests.
  53. """
  54. if self.dht:
  55. if not self.dht.is_alive():
  56. self.dht.run_in_background(await_ready=True)
  57. dht_handler_thread = DHTHandlerThread(
  58. experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
  59. dht_handler_thread.start()
  60. if self.checkpoint_saver is not None:
  61. self.checkpoint_saver.start()
  62. for process in self.conn_handlers:
  63. if not process.is_alive():
  64. process.start()
  65. for process in self.conn_handlers:
  66. process.ready.wait()
  67. self.runtime.run()
  68. for process in self.conn_handlers:
  69. process.join()
  70. if self.dht:
  71. dht_handler_thread.stop.set()
  72. dht_handler_thread.join()
  73. if self.checkpoint_saver is not None:
  74. self.checkpoint_saver.stop.set()
  75. self.checkpoint_saver.join()
  76. def run_in_background(self, await_ready=True, timeout=None):
  77. """
  78. Starts Server in a background thread. if await_ready, this method will wait until background server
  79. is ready to process incoming requests or for :timeout: seconds max.
  80. """
  81. self.start()
  82. if await_ready and not self.ready.wait(timeout=timeout):
  83. raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
  84. @property
  85. def ready(self) -> mp.synchronize.Event:
  86. """
  87. An event (multiprocessing.Event) that is set when the server is ready to process requests.
  88. Example
  89. =======
  90. >>> server.start()
  91. >>> server.ready.wait(timeout=10)
  92. >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
  93. """
  94. return self.runtime.ready # mp.Event that is true if self is ready to process batches
  95. def shutdown(self):
  96. """
  97. Gracefully terminate a hivemind server, process-safe.
  98. Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
  99. If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
  100. """
  101. self.ready.clear()
  102. for process in self.conn_handlers:
  103. process.terminate()
  104. if self.dht is not None:
  105. self.dht.shutdown()
  106. self.dht.join()
  107. self.runtime.shutdown()