server.py

from __future__ import annotations

import multiprocessing as mp
import random
import threading
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from multiaddr import Multiaddr

from hivemind.dht import DHT
from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
from hivemind.moe.server.expert_backend import ExpertBackend
from hivemind.moe.server.expert_uid import UID_DELIMITER
from hivemind.moe.server.layers import (
    add_custom_models_from_file,
    name_to_block,
    name_to_input,
    schedule_name_to_scheduler,
)
from hivemind.moe.server.runtime import Runtime
from hivemind.p2p import PeerInfo
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger
from hivemind.utils.tensor_descr import BatchTensorDescriptor

logger = get_logger(__name__)


class Server(threading.Thread):
    """
    Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
    After creation, a server should be started: see Server.run or Server.run_in_background.

    A working server does 3 things:
     - processes incoming forward/backward requests via Runtime (created by the server)
     - publishes updates to expert status every :update_period: seconds
     - follows orders from HivemindController - if it exists

    :param dht: a DHT node used to publish this server's experts (see DHTHandlerThread)
    :param expert_backends: dict {expert uid (str): ExpertBackend} for all experts hosted by this server
    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
        is too small for normal functioning; we recommend 4 handlers per expert backend.
    :param update_period: how often the server will attempt to publish its state (i.e. experts) to the DHT;
        if dht is None, this parameter is ignored.
    :param checkpoint_dir: if specified, periodically save expert checkpoints to this directory (via CheckpointSaver)
    :param start: if True, the server will immediately start as a background thread and return control after the
        server is ready (see .ready below)
    """

    def __init__(
        self,
        dht: DHT,
        expert_backends: Dict[str, ExpertBackend],
        num_connection_handlers: int = 1,
        update_period: int = 30,
        start=False,
        checkpoint_dir=None,
        **kwargs,
    ):
        super().__init__()
        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
        if checkpoint_dir is not None:
            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
        else:
            self.checkpoint_saver = None
        self.runtime = Runtime(self.experts, **kwargs)

        if self.experts:
            self.dht_handler_thread = DHTHandlerThread(
                experts=self.experts,
                dht=self.dht,
                peer_id=self.dht.peer_id,
                update_period=self.update_period,
                daemon=True,
            )

        if start:
            self.run_in_background(await_ready=True)

    @classmethod
    def create(
        cls,
        num_experts: int = None,
        expert_uids: str = None,
        expert_pattern: str = None,
        expert_cls="ffn",
        hidden_dim=1024,
        optim_cls=torch.optim.Adam,
        scheduler: str = "none",
        num_warmup_steps=None,
        num_total_steps=None,
        clip_grad_norm=None,
        num_handlers=None,
        min_batch_size=1,
        max_batch_size=4096,
        device=None,
        initial_peers=(),
        checkpoint_dir: Optional[Path] = None,
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        *,
        start: bool,
        **kwargs,
    ) -> Server:
  101. """
  102. Instantiate a server with several identical experts. See argparse comments below for details
  103. :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
  104. :param num_experts: run this many identical experts
  105. :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
  106. means "sample random experts between myprefix.0.0 and myprefix.255.255;
  107. :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
  108. :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
  109. :param hidden_dim: main dimension for expert_cls
  110. :param num_handlers: server will use this many parallel processes to handle incoming requests
  111. :param min_batch_size: total num examples in the same batch will be greater than this value
  112. :param max_batch_size: total num examples in the same batch will not exceed this value
  113. :param device: all experts will use this device in torch notation; default: cuda if available else cpu
  114. :param optim_cls: uses this optimizer to train all experts
  115. :param scheduler: if not `none`, the name of the expert LR scheduler
  116. :param num_warmup_steps: the number of warmup steps for LR schedule
  117. :param num_total_steps: the total number of steps for LR schedule
  118. :param clip_grad_norm: maximum gradient norm used for clipping
  119. :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
  120. :param checkpoint_dir: directory to save and load expert checkpoints
  121. :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
  122. hosted on this server. For a more fine-grained compression, start server in python and specify compression
  123. for each BatchTensorProto in ExpertBackend for the respective experts.
  124. :param start: if True, starts server right away and returns when server is ready for requests
  125. :param stats_report_interval: interval between two reports of batch processing performance statistics
  126. :param kwargs: any other params will be forwarded to DHT upon creation
  127. """
        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)
        assert expert_cls in name_to_block

        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
            num_experts is not None and expert_uids is None
        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"

        if expert_uids is None:
            if checkpoint_dir is not None:
                assert is_directory(checkpoint_dir)
                expert_uids = [
                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
                ]
                total_experts_in_checkpoint = len(expert_uids)
                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")

                if total_experts_in_checkpoint > num_experts:
                    raise ValueError(
                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
                    )
            else:
                expert_uids = []

            uids_to_generate = num_experts - len(expert_uids)
            if uids_to_generate > 0:
                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
                expert_uids.extend(_generate_uids(uids_to_generate, expert_pattern, dht))

        num_experts = len(expert_uids)
        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        sample_input = name_to_input[expert_cls](3, hidden_dim)
        if isinstance(sample_input, tuple):
            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
        else:
            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)

        scheduler = schedule_name_to_scheduler[scheduler]

        # initialize experts
        experts = {}
        for expert_uid in expert_uids:
            expert = name_to_block[expert_cls](hidden_dim)
            experts[expert_uid] = ExpertBackend(
                name=expert_uid,
                expert=expert,
                args_schema=args_schema,
                optimizer=optim_cls(expert.parameters()),
                scheduler=scheduler,
                num_warmup_steps=num_warmup_steps,
                num_total_steps=num_total_steps,
                clip_grad_norm=clip_grad_norm,
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )

        if checkpoint_dir is not None:
            load_experts(experts, checkpoint_dir)

        return cls(
            dht,
            experts,
            num_connection_handlers=num_handlers,
            device=device,
            checkpoint_dir=checkpoint_dir,
            stats_report_interval=stats_report_interval,
            start=start,
        )
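
    # A minimal usage sketch (hypothetical argument values; `expert_pattern` follows the syntax of
    # _generate_uids below, and `initial_peers` would normally point at an existing DHT node's multiaddrs):
    #
    #     server = Server.create(
    #         num_experts=2,
    #         expert_cls="ffn",
    #         hidden_dim=1024,
    #         expert_pattern="ffn.[0:256]",
    #         start=True,
    #     )
    #     server.ready.wait(timeout=30)  # block until the Runtime is ready to process batches
    #     ...
    #     server.shutdown()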

    def run(self):
        """
        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Server started with {len(self.experts)} experts:")
        for expert_name, backend in self.experts.items():
            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")

        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        if self.experts:
            self.dht_handler_thread.start()
        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()
            process.ready.result()

        try:
            self.runtime.run()
        finally:
            self.shutdown()

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts Server in a background thread. If await_ready, this method will wait until the background server
        is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the server is ready to process requests.

        Example
        =======
        >>> server.start()
        >>> server.ready.wait(timeout=10)
        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the server, process-safe.
        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        self.ready.clear()

        for process in self.conn_handlers:
            process.terminate()
            process.join()
        logger.debug("Connection handlers terminated")

        if self.experts:
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        self.dht.shutdown()
        self.dht.join()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()
        logger.info("Server shut down successfully")


@contextmanager
def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
    """A context manager that creates server in a background thread, awaits .ready on entry and shuts down on exit"""
    pipe, runners_pipe = mp.Pipe(duplex=True)
    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)

    try:
        runner.start()
        # once the server is ready, runner will send us
        # either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
        start_ok, data = pipe.recv()
        if start_ok:
            yield data
            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
        else:
            raise RuntimeError(f"Server failed to start: {data}")
    finally:
        runner.join(timeout=shutdown_timeout)
        if runner.is_alive():
            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
            runner.kill()
            logger.info("Server terminated")
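
# A minimal usage sketch (hypothetical arguments; they are forwarded to Server.create inside the runner process,
# which already passes start=True):
#
#     with background_server(num_experts=1, expert_cls="ffn", hidden_dim=512) as peer_info:
#         print(peer_info)  # PeerInfo with the DHT peer id and visible multiaddrs of the background server
#     # leaving the with-block sends the shutdown signal and joins the runner process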


def _server_runner(pipe, *args, **kwargs):
    try:
        server = Server.create(*args, start=True, **kwargs)
    except Exception as e:
        logger.exception(f"Encountered an exception when starting a server: {e}")
        pipe.send((False, f"{type(e).__name__} {e}"))
        return

    try:
        dht_maddrs = server.dht.get_visible_maddrs()
        pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
        pipe.recv()  # wait for shutdown signal
    finally:
        logger.info("Shutting down server...")
        server.shutdown()
        server.join()
        logger.info("Server shut down")


def _generate_uids(
    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
) -> List[str]:
    """
    Sample experts from a given pattern, remove duplicates.
    :param num_experts: sample this many unique expert uids
    :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]
        means "sample random experts between myprefix.0.0 and myprefix.31.255"
    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
        a small chance of sampling duplicate expert uids.
    """
    remaining_attempts = attempts_per_expert * num_experts
    found_uids, attempted_uids = list(), set()

    def _generate_uid():
        if expert_pattern is None:
            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"

        uid = []
        for block in expert_pattern.split(UID_DELIMITER):
            try:
                if "[" not in block and "]" not in block:
                    uid.append(block)
                elif block.startswith("[") and block.endswith("]") and ":" in block:
                    slice_start, slice_end = map(int, block[1:-1].split(":"))
                    uid.append(str(random.randint(slice_start, slice_end - 1)))
                else:
                    raise ValueError("Block must be either fixed or a range [from:to]")
            except KeyboardInterrupt:
                raise
            except Exception as e:
                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
        return UID_DELIMITER.join(uid)
    while remaining_attempts > 0 and len(found_uids) < num_experts:
        # 1. sample new expert uids at random
        new_uids = []
        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
            new_uid = _generate_uid()
            remaining_attempts -= 1
            if new_uid not in attempted_uids:
                attempted_uids.add(new_uid)
                new_uids.append(new_uid)

        # 2. look into DHT (if given) and remove duplicates
        if dht is not None:
            existing_expert_uids = {
                found_expert.uid for found_expert in get_experts(dht, new_uids) if found_expert is not None
            }
            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]

        found_uids += new_uids

    if len(found_uids) != num_experts:
        logger.warning(
            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
            f"{attempts_per_expert * num_experts} attempts"
        )
    return found_uids
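
# A hypothetical illustration of the pattern syntax accepted by _generate_uids: fixed blocks are copied
# verbatim, while "[from:to]" blocks are replaced with a random integer in [from, to):
#
#     _generate_uids(2, "ffn.[0:4].[0:256]")  # may return e.g. ["ffn.3.17", "ffn.0.129"]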