# __init__.py

from __future__ import annotations

import multiprocessing as mp
import multiprocessing.synchronize
import threading
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from multiaddr import Multiaddr

import hivemind
from hivemind.dht import DHT
from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
from hivemind.moe.server.expert_backend import ExpertBackend
from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
from hivemind.moe.server.layers import (
    add_custom_models_from_file,
    name_to_block,
    name_to_input,
    register_expert_class,
    schedule_name_to_scheduler,
)
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import BatchTensorDescriptor, Endpoint, find_open_port, get_logger, get_port, replace_port

logger = get_logger(__name__)


class Server(threading.Thread):
    """
    Server allows you to host "experts" - pytorch sub-networks used by the Decentralized Mixture-of-Experts.
    After creation, a server should be started: see Server.run or Server.run_in_background.

    A working server does three things:
     - processes incoming forward/backward requests via Runtime (created by the server)
     - publishes updates to expert status every :update_period: seconds
     - follows orders from HivemindController - if it exists

    :type dht: DHT or None. A server with dht=None will NOT be visible from the DHT,
     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
    :param expert_backends: dict {expert uid (str): ExpertBackend} for all experts hosted by this server.
    :param listen_on: the server's network address that determines how it can be accessed. Address and (optional) port
    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
     is too small for normal operation; we recommend 4 handlers per expert backend.
    :param update_period: how often the server will attempt to publish its state (i.e. experts) to the DHT;
     if dht is None, this parameter is ignored.
    :param start: if True, the server will immediately start as a background thread and return control after the
     server is ready (see .ready below)
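
    Example
    =======
    A minimal usage sketch (illustrative only; assumes `backends` is a dict mapping expert uid to an
    already-constructed ExpertBackend - most users create servers via Server.create instead):

    >>> server = Server(dht=None, expert_backends=backends, num_connection_handlers=4, start=True)
    >>> server.ready.wait(timeout=10)
    >>> # ... serve requests ...
    >>> server.shutdown()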
  47. """
    def __init__(
        self,
        dht: Optional[DHT],
        expert_backends: Dict[str, ExpertBackend],
        listen_on: Endpoint = "0.0.0.0:*",
        num_connection_handlers: int = 1,
        update_period: int = 30,
        start=False,
        checkpoint_dir=None,
        **kwargs,
    ):
        super().__init__()
        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
        if get_port(listen_on) is None:
            listen_on = replace_port(listen_on, new_port=find_open_port())
        self.listen_on, self.port = listen_on, get_port(listen_on)

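        # each connection handler runs as a separate process; all of them serve the same expert backends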
        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
        if checkpoint_dir is not None:
            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
        else:
            self.checkpoint_saver = None
        self.runtime = Runtime(self.experts, **kwargs)

        if self.dht and self.experts:
            self.dht_handler_thread = DHTHandlerThread(
                experts=self.experts,
                dht=self.dht,
                endpoint=self.listen_on,
                update_period=self.update_period,
                daemon=True,
            )

        if start:
            self.run_in_background(await_ready=True)

    @classmethod
    def create(
        cls,
        listen_on="0.0.0.0:*",
        num_experts: int = None,
        expert_uids: str = None,
        expert_pattern: str = None,
        expert_cls="ffn",
        hidden_dim=1024,
        optim_cls=torch.optim.Adam,
        scheduler: str = "none",
        num_warmup_steps=None,
        num_total_steps=None,
        clip_grad_norm=None,
        num_handlers=None,
        min_batch_size=1,
        max_batch_size=4096,
        device=None,
        no_dht=False,
        initial_peers=(),
        checkpoint_dir: Optional[Path] = None,
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        *,
        start: bool,
    ) -> Server:
  107. """
  108. Instantiate a server with several identical experts. See argparse comments below for details
  109. :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
  110. :param num_experts: run this many identical experts
  111. :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
  112. means "sample random experts between myprefix.0.0 and myprefix.255.255;
  113. :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
  114. :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
  115. :param hidden_dim: main dimension for expert_cls
  116. :param num_handlers: server will use this many parallel processes to handle incoming requests
  117. :param min_batch_size: total num examples in the same batch will be greater than this value
  118. :param max_batch_size: total num examples in the same batch will not exceed this value
  119. :param device: all experts will use this device in torch notation; default: cuda if available else cpu
  120. :param optim_cls: uses this optimizer to train all experts
  121. :param scheduler: if not `none`, the name of the expert LR scheduler
  122. :param num_warmup_steps: the number of warmup steps for LR schedule
  123. :param num_total_steps: the total number of steps for LR schedule
  124. :param clip_grad_norm: maximum gradient norm used for clipping
  125. :param no_dht: if specified, the server will not be attached to a dht
  126. :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
  127. :param checkpoint_dir: directory to save and load expert checkpoints
  128. :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
  129. hosted on this server. For a more fine-grained compression, start server in python and specify compression
  130. for each BatchTensorProto in ExpertBackend for the respective experts.
  131. :param start: if True, starts server right away and returns when server is ready for requests
  132. :param stats_report_interval: interval between two reports of batch processing performance statistics
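
        Example
        =======
        A minimal sketch with illustrative argument values (the uid pattern, hidden_dim and num_experts
        below are placeholders, not recommended settings):

        >>> server = Server.create(num_experts=2, expert_pattern="expert.[0:256]", expert_cls="ffn",
        ...                        hidden_dim=64, no_dht=True, start=True)
        >>> server.ready.wait(timeout=30)
        >>> server.shutdown()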
  133. """
        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)
        assert expert_cls in name_to_block

        if no_dht:
            dht = None
        else:
            dht = hivemind.DHT(initial_peers=initial_peers, start=True)
            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
            num_experts is not None and expert_uids is None
        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"

        if expert_uids is None:
            if checkpoint_dir is not None:
                assert is_directory(checkpoint_dir)
                expert_uids = [
                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
                ]
                total_experts_in_checkpoint = len(expert_uids)
                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")

                if total_experts_in_checkpoint > num_experts:
                    raise ValueError(
                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
                    )
            else:
                expert_uids = []

            uids_to_generate = num_experts - len(expert_uids)
            if uids_to_generate > 0:
                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))

        num_experts = len(expert_uids)
        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")

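        # infer the experts' input schema from a small sample batch; each BatchTensorDescriptor records
        # shape, dtype and compression of an argument rather than the sample values themselves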
        sample_input = name_to_input[expert_cls](3, hidden_dim)
        if isinstance(sample_input, tuple):
            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
        else:
            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)

        scheduler = schedule_name_to_scheduler[scheduler]

        # initialize experts
        experts = {}
        for expert_uid in expert_uids:
            expert = name_to_block[expert_cls](hidden_dim)
            experts[expert_uid] = hivemind.ExpertBackend(
                name=expert_uid,
                expert=expert,
                args_schema=args_schema,
                optimizer=optim_cls(expert.parameters()),
                scheduler=scheduler,
                num_warmup_steps=num_warmup_steps,
                num_total_steps=num_total_steps,
                clip_grad_norm=clip_grad_norm,
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )

        if checkpoint_dir is not None:
            load_experts(experts, checkpoint_dir)

        return cls(
            dht,
            experts,
            listen_on=listen_on,
            num_connection_handlers=num_handlers,
            device=device,
            checkpoint_dir=checkpoint_dir,
            stats_report_interval=stats_report_interval,
            start=start,
        )

    def run(self):
        """
        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Server started at {self.listen_on}")
        logger.info(f"Got {len(self.experts)} experts:")
        for expert_name, backend in self.experts.items():
            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")

        if self.dht:
            if not self.dht.is_alive():
                self.dht.run_in_background(await_ready=True)
            if self.experts:
                self.dht_handler_thread.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()
            process.ready.wait()

        try:
            self.runtime.run()
        finally:
            self.shutdown()

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts Server in a background thread. If await_ready, this method will wait until the background server
        is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the server is ready to process requests.

        Example
        =======
        >>> server.start()
        >>> server.ready.wait(timeout=10)
        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the server, process-safe.
        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        self.ready.clear()

        for process in self.conn_handlers:
            process.terminate()
            process.join()
        logger.debug("Connection handlers terminated")

        if self.dht and self.experts:
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        if self.dht is not None:
            self.dht.shutdown()
            self.dht.join()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()
        logger.info("Server shut down successfully")


@contextmanager
def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, List[Multiaddr]]:
    """A context manager that creates a server in a background process, awaits .ready on entry and shuts it down on exit"""
    pipe, runners_pipe = mp.Pipe(duplex=True)
    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)

    try:
        runner.start()
        # once the server is ready, runner will send us
        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
        start_ok, data = pipe.recv()
        if start_ok:
            yield data
            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
        else:
            raise RuntimeError(f"Server failed to start: {data}")
    finally:
        runner.join(timeout=shutdown_timeout)
        if runner.is_alive():
            logger.info("Server failed to shut down gracefully, terminating it the hard way...")
            runner.kill()
            logger.info("Server terminated.")


def _server_runner(pipe, *args, **kwargs):
    try:
        server = Server.create(*args, start=True, **kwargs)
    except Exception as e:
        logger.exception(f"Encountered an exception when starting a server: {e}")
        pipe.send((False, f"{type(e).__name__} {e}"))
        return

    try:
        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
        pipe.send((True, (server.listen_on, dht_maddrs)))
        pipe.recv()  # wait for shutdown signal
    finally:
        logger.info("Shutting down server...")
        server.shutdown()
        server.join()
        logger.info("Server shut down.")