__init__.py

from __future__ import annotations

import multiprocessing as mp
import multiprocessing.synchronize
import threading
from contextlib import contextmanager
from functools import partial
from typing import Dict, Optional, Tuple
from pathlib import Path

import torch

import hivemind
from hivemind.dht import DHT
from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
from hivemind.moe.server.checkpoints import CheckpointSaver, load_experts, is_directory
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
from hivemind.moe.server.expert_backend import ExpertBackend
from hivemind.moe.server.layers import name_to_block, name_to_input, register_expert_class
from hivemind.moe.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
from hivemind.moe.server.runtime import Runtime
from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger, BatchTensorDescriptor
from hivemind.proto.runtime_pb2 import CompressionType

logger = get_logger(__name__)


class Server(threading.Thread):
  24. """
  25. Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
  26. After creation, a server should be started: see Server.run or Server.run_in_background.
  27. A working server does 3 things:
  28. - processes incoming forward/backward requests via Runtime (created by the server)
  29. - publishes updates to expert status every :update_period: seconds
  30. - follows orders from HivemindController - if it exists
  31. :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
  32. but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
  33. :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
  34. :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
  35. :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
  36. if too small for normal functioning, we recommend 4 handlers per expert backend.
  37. :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
  38. if dht is None, this parameter is ignored.
  39. :param start: if True, the server will immediately start as a background thread and returns control after server
  40. is ready (see .ready below)
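
    Example (a minimal sketch; ``my_backend`` stands for a pre-built hivemind.ExpertBackend and is not defined here):

    >>> server = Server(dht=None, expert_backends={'expert.0': my_backend}, start=True)
    >>> server.ready.wait(timeout=10)  # block until the server can process requests
    >>> server.shutdown()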
  41. """
    def __init__(
            self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], listen_on: Endpoint = "0.0.0.0:*",
            num_connection_handlers: int = 1, update_period: int = 30, start=False, checkpoint_dir=None, **kwargs):
        super().__init__()
        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
        if get_port(listen_on) is None:
            listen_on = replace_port(listen_on, new_port=find_open_port())
        self.listen_on, self.port = listen_on, get_port(listen_on)

        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
        if checkpoint_dir is not None:
            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
        else:
            self.checkpoint_saver = None
        self.runtime = Runtime(self.experts, **kwargs)

        if self.dht and self.experts:
            self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
                                                       update_period=self.update_period, daemon=True)

        if start:
            self.run_in_background(await_ready=True)

    @classmethod
    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1,
               max_batch_size=4096, device=None, no_dht=False, initial_peers=(), dht_port=None,
               checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE,
               stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server:
  68. """
  69. Instantiate a server with several identical experts. See argparse comments below for details
  70. :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
  71. :param num_experts: run this many identical experts
  72. :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
  73. means "sample random experts between myprefix.0.0 and myprefix.255.255;
  74. :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
  75. :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
  76. :param hidden_dim: main dimension for expert_cls
  77. :param num_handlers: server will use this many parallel processes to handle incoming requests
  78. :param min_batch_size: total num examples in the same batch will be greater than this value
  79. :param max_batch_size: total num examples in the same batch will not exceed this value
  80. :param device: all experts will use this device in torch notation; default: cuda if available else cpu
  81. :param optim_cls: uses this optimizer to train all experts
  82. :param scheduler: if not `none`, the name of the expert LR scheduler
  83. :param num_warmup_steps: the number of warmup steps for LR schedule
  84. :param num_total_steps: the total number of steps for LR schedule
  85. :param clip_grad_norm: maximum gradient norm used for clipping
  86. :param no_dht: if specified, the server will not be attached to a dht
  87. :param initial_peers: a list of peers that will introduce this node to the dht,\
  88. e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers
  89. :param dht_port: DHT node will listen on this port, default = find open port
  90. You can then use this node as initial peer for subsequent servers.
  91. :param checkpoint_dir: directory to save and load expert checkpoints
  92. :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
  93. hosted on this server. For a more fine-grained compression, start server in python and specify compression
  94. for each BatchTensorProto in ExpertBackend for the respective experts.
  95. :param start: if True, starts server right away and returns when server is ready for requests
  96. :param stats_report_interval: interval between two reports of batch processing performance statistics
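
        Example (a hedged sketch; the uid pattern, hidden_dim and other hyperparameters below are arbitrary choices):

        >>> server = Server.create(num_experts=2, expert_pattern='expert.[0:256]', expert_cls='ffn',
        ...                        hidden_dim=512, no_dht=True, start=True)
        >>> server.ready.wait(timeout=30)
        >>> server.shutdown()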
  97. """
        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)
        assert expert_cls in name_to_block

        if no_dht:
            dht = None
        else:
            dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
            dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
            logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")

        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
                (num_experts is not None and expert_uids is None)), \
            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"

        if expert_uids is None:
            if checkpoint_dir is not None:
                assert is_directory(checkpoint_dir)
                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
                               (child / 'checkpoint_last.pt').exists()]
                total_experts_in_checkpoint = len(expert_uids)
                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")

                if total_experts_in_checkpoint > num_experts:
                    raise ValueError(
                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
            else:
                expert_uids = []

            uids_to_generate = num_experts - len(expert_uids)
            if uids_to_generate > 0:
                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))

        num_experts = len(expert_uids)
        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

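        # note: the small sample batch below is used only to infer the experts' input schema (shapes and dtypes),
        # it is never used for training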
        sample_input = name_to_input[expert_cls](3, hidden_dim)
        if isinstance(sample_input, tuple):
            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
        else:
            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)

        scheduler = schedule_name_to_scheduler[scheduler]

        # initialize experts
        experts = {}
        for expert_uid in expert_uids:
            expert = name_to_block[expert_cls](hidden_dim)
            experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
                                                         args_schema=args_schema,
                                                         optimizer=optim_cls(expert.parameters()),
                                                         scheduler=scheduler,
                                                         num_warmup_steps=num_warmup_steps,
                                                         num_total_steps=num_total_steps,
                                                         clip_grad_norm=clip_grad_norm,
                                                         min_batch_size=min_batch_size,
                                                         max_batch_size=max_batch_size)

        if checkpoint_dir is not None:
            load_experts(experts, checkpoint_dir)

        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
                   checkpoint_dir=checkpoint_dir, stats_report_interval=stats_report_interval, start=start)

    def run(self):
        """
        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Server started at {self.listen_on}")
        logger.info(f"Got {len(self.experts)} experts:")
        for expert_name, backend in self.experts.items():
            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")

        if self.dht:
            if not self.dht.is_alive():
                self.dht.run_in_background(await_ready=True)

            if self.experts:
                self.dht_handler_thread.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()
            process.ready.wait()

        try:
            self.runtime.run()
        finally:
            self.shutdown()

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts Server in a background thread. If await_ready, this method will wait until the background server
        is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the server is ready to process requests.

        Example
        =======
        >>> server.start()
        >>> server.ready.wait(timeout=10)
        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the server, process-safe.
        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        self.ready.clear()

        for process in self.conn_handlers:
            process.terminate()
            process.join()
        logger.debug("Connection handlers terminated")

        if self.dht and self.experts:
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        if self.dht is not None:
            self.dht.shutdown()
            self.dht.join()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()
        logger.info("Server shut down successfully")


@contextmanager
def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
    """ A context manager that creates a server in a background process, awaits .ready on entry and shuts it down on exit """
    pipe, runners_pipe = mp.Pipe(duplex=True)
    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)

    try:
        runner.start()
        # once the server is ready, runner will send us either (False, exception) or (True, (server_port, dht_port))
        start_ok, data = pipe.recv()
        if start_ok:
            yield data
            pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
        else:
            raise RuntimeError(f"Server failed to start: {data}")
    finally:
        runner.join(timeout=shutdown_timeout)
        if runner.is_alive():
            logger.info("Server failed to shut down gracefully, terminating it the hard way...")
            runner.kill()
            logger.info("Server terminated.")


def _server_runner(pipe, *args, **kwargs):
    try:
        server = Server.create(*args, start=True, **kwargs)
    except Exception as e:
        logger.exception(f"Encountered an exception when starting a server: {e}")
        pipe.send((False, f'{type(e).__name__} {e}'))
        return

    try:
        if server.dht is not None:
            dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
        else:
            dht_listen_on = None

        pipe.send((True, (server.listen_on, dht_listen_on)))
        pipe.recv()  # wait for shutdown signal
    finally:
        logger.info("Shutting down server...")
        server.shutdown()
        server.join()
        logger.info("Server shut down.")