from __future__ import annotations

import multiprocessing as mp
import random
import threading
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from multiaddr import Multiaddr

from hivemind.dht import DHT
from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
from hivemind.moe.server.expert_backend import ExpertBackend
from hivemind.moe.server.expert_uid import UID_DELIMITER
from hivemind.moe.server.layers import (
    add_custom_models_from_file,
    name_to_block,
    name_to_input,
    schedule_name_to_scheduler,
)
from hivemind.moe.server.runtime import Runtime
from hivemind.p2p import PeerInfo
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger
from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor

logger = get_logger(__name__)


class Server(threading.Thread):
    """
    Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
    After creation, a server should be started: see Server.run or Server.run_in_background.

    A working server does two things:
     - processes incoming forward/backward requests via Runtime (created by the server)
     - publishes updates to expert status every :update_period: seconds

    :param dht: an instance of hivemind.DHT; the server publishes its experts to this DHT.
    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all experts hosted by this server.
    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
        is too small for normal functioning; we recommend 4 handlers per expert backend.
    :param update_period: how often the server will attempt to publish its state (i.e. experts) to the DHT;
        if dht is None, this parameter is ignored.
    :param start: if True, the server will immediately start as a background thread and return control after the
        server is ready (see .ready below)
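
    Example (a minimal sketch; assumes `dht` is a running hivemind.DHT and `backends` is a prebuilt
    dict of expert uid -> ExpertBackend):

    >>> server = Server(dht, backends, num_connection_handlers=4, start=True)
    >>> server.ready.wait()  # blocks until the server can process requests
    >>> server.shutdown()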
    """

    def __init__(
        self,
        dht: DHT,
        expert_backends: Dict[str, ExpertBackend],
        num_connection_handlers: int = 1,
        update_period: int = 30,
        start=False,
        checkpoint_dir=None,
        **kwargs,
    ):
        super().__init__()
        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
        if checkpoint_dir is not None:
            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
        else:
            self.checkpoint_saver = None
        self.runtime = Runtime(self.experts, **kwargs)

        if self.experts:
            self.dht_handler_thread = DHTHandlerThread(
                experts=self.experts,
                dht=self.dht,
                peer_id=self.dht.peer_id,
                update_period=self.update_period,
                daemon=True,
            )

        if start:
            self.run_in_background(await_ready=True)

    @classmethod
    def create(
        cls,
        num_experts: Optional[int] = None,
        expert_uids: Optional[List[str]] = None,
        expert_pattern: Optional[str] = None,
        expert_cls="ffn",
        hidden_dim=1024,
        optim_cls=torch.optim.Adam,
        scheduler: str = "none",
        num_warmup_steps=None,
        num_total_steps=None,
        clip_grad_norm=None,
        num_handlers=None,
        min_batch_size=1,
        max_batch_size=4096,
        device=None,
        initial_peers=(),
        checkpoint_dir: Optional[Path] = None,
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        *,
        start: bool,
        **kwargs,
    ) -> Server:
        """
        Instantiate a server with several identical experts. See the argument descriptions below for details.

        :param num_experts: run this many identical experts
        :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
            means "sample random experts between myprefix.0.0 and myprefix.31.255"
        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer'
        :param hidden_dim: main dimension for expert_cls
        :param num_handlers: server will use this many parallel processes to handle incoming requests
        :param min_batch_size: total number of examples in the same batch will be at least this value
        :param max_batch_size: total number of examples in the same batch will not exceed this value
        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
        :param optim_cls: uses this optimizer to train all experts
        :param scheduler: if not `none`, the name of the expert LR scheduler
        :param num_warmup_steps: the number of warmup steps for LR schedule
        :param num_total_steps: the total number of steps for LR schedule
        :param clip_grad_norm: maximum gradient norm used for clipping
        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
        :param checkpoint_dir: directory to save and load expert checkpoints
        :param compression: if specified, use this compression to pack all inputs, outputs and gradients of all
            experts hosted on this server. For more fine-grained compression, start the server in Python and specify
            compression for each BatchTensorDescriptor in ExpertBackend for the respective experts.
        :param start: if True, starts server right away and returns when server is ready for requests
        :param stats_report_interval: interval between two reports of batch processing performance statistics
        :param kwargs: any other params will be forwarded to DHT upon creation
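
        Example (a minimal sketch; the uid pattern and hidden_dim below are illustrative, not the defaults):

        >>> server = Server.create(num_experts=2, expert_pattern="expert.[0:256]", hidden_dim=512, start=True)
        >>> # ... serve forward/backward requests ...
        >>> server.shutdown()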
        """
        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)
        assert expert_cls in name_to_block

        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
            num_experts is not None and expert_uids is None
        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"

        if expert_uids is None:
            if checkpoint_dir is not None:
                assert is_directory(checkpoint_dir)
                expert_uids = [
                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
                ]
                total_experts_in_checkpoint = len(expert_uids)
                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")

                if total_experts_in_checkpoint > num_experts:
                    raise ValueError(
                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
                    )
            else:
                expert_uids = []

            uids_to_generate = num_experts - len(expert_uids)
            if uids_to_generate > 0:
                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
                expert_uids.extend(_generate_uids(uids_to_generate, expert_pattern, dht))

        num_experts = len(expert_uids)
        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim)
        if isinstance(sample_input, tuple):
            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
        else:
            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
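
        # args_schema holds one BatchTensorDescriptor per expert input tensor, capturing the sample shape, dtype
        # and compression; ExpertBackend uses this schema to describe the inputs it expects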
        scheduler = schedule_name_to_scheduler[scheduler]

        # initialize experts
        experts = {}
        for expert_uid in expert_uids:
            expert = name_to_block[expert_cls](hidden_dim)
            experts[expert_uid] = ExpertBackend(
                name=expert_uid,
                expert=expert,
                args_schema=args_schema,
                optimizer=optim_cls(expert.parameters()),
                scheduler=scheduler,
                num_warmup_steps=num_warmup_steps,
                num_total_steps=num_total_steps,
                clip_grad_norm=clip_grad_norm,
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )

        if checkpoint_dir is not None:
            load_experts(experts, checkpoint_dir)

        return cls(
            dht,
            experts,
            num_connection_handlers=num_handlers,
            device=device,
            checkpoint_dir=checkpoint_dir,
            stats_report_interval=stats_report_interval,
            start=start,
        )

    def run(self):
        """
        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Server started with {len(self.experts)} experts:")
        for expert_name, backend in self.experts.items():
            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")

        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        if self.experts:
            self.dht_handler_thread.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()
            process.ready.result()

        try:
            self.runtime.run()
        finally:
            self.shutdown()

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts Server in a background thread. If await_ready, this method will wait until the background server
        is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the server is ready to process requests.

        Example
        =======
        >>> server.start()
        >>> server.ready.wait(timeout=10)
        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the server, process-safe.
        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
        If you have already caused a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        self.ready.clear()

        for process in self.conn_handlers:
            process.terminate()
            process.join()
        logger.debug("Connection handlers terminated")

        if self.experts:
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        self.dht.shutdown()
        self.dht.join()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()

        logger.info("Server shut down successfully")


@contextmanager
def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
    """A context manager that creates a server in a background process, awaits .ready on entry and shuts it down on exit."""
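    # Example usage (a sketch; the arguments below are illustrative and PeerInfo is assumed to expose
    # the server's DHT peer id and multiaddrs):
    #     with background_server(num_experts=1, expert_pattern="expert.[0:256]", hidden_dim=64) as peer_info:
    #         client_dht = DHT(initial_peers=peer_info.addrs, start=True)  # hypothetical client-side DHT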
    pipe, runners_pipe = mp.Pipe(duplex=True)
    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)

    try:
        runner.start()
        # once the server is ready, runner will send us
        # either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
        start_ok, data = pipe.recv()
        if start_ok:
            yield data
            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
        else:
            raise RuntimeError(f"Server failed to start: {data}")
    finally:
        runner.join(timeout=shutdown_timeout)
        if runner.is_alive():
            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
            runner.kill()
            logger.info("Server terminated")


def _server_runner(pipe, *args, **kwargs):
    try:
        server = Server.create(*args, start=True, **kwargs)
    except Exception as e:
        logger.exception(f"Encountered an exception when starting a server: {e}")
        pipe.send((False, f"{type(e).__name__} {e}"))
        return

    try:
        dht_maddrs = server.dht.get_visible_maddrs()
        pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
        pipe.recv()  # wait for shutdown signal
    finally:
        logger.info("Shutting down server...")
        server.shutdown()
        server.join()
        logger.info("Server shut down")


def _generate_uids(
    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
) -> List[str]:
    """
    Sample expert uids from a given pattern and remove duplicates.

    :param num_experts: sample this many unique expert uids
    :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
        means "sample random experts between myprefix.0.0 and myprefix.31.255"
    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
        a small chance of sampling duplicate expert uids.
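
    Example (illustrative; the output is random): _generate_uids(2, "expert.[0:4]") may return ["expert.3", "expert.0"];
    fixed blocks are kept verbatim, so a pattern like "ffn.[0:2].head" can yield "ffn.0.head" or "ffn.1.head".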
    """
    remaining_attempts = attempts_per_expert * num_experts
    found_uids, attempted_uids = list(), set()

    def _generate_uid():
        if expert_pattern is None:
            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"

        uid = []
        for block in expert_pattern.split(UID_DELIMITER):
            try:
                if "[" not in block and "]" not in block:
                    uid.append(block)
                elif block.startswith("[") and block.endswith("]") and ":" in block:
                    slice_start, slice_end = map(int, block[1:-1].split(":"))
                    uid.append(str(random.randint(slice_start, slice_end - 1)))
                else:
                    raise ValueError("Block must be either fixed or a range [from:to]")
            except KeyboardInterrupt:
                raise
            except Exception as e:
                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
        return UID_DELIMITER.join(uid)

    while remaining_attempts > 0 and len(found_uids) < num_experts:
        # 1. sample new expert uids at random
        new_uids = []
        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
            new_uid = _generate_uid()
            remaining_attempts -= 1
            if new_uid not in attempted_uids:
                attempted_uids.add(new_uid)
                new_uids.append(new_uid)

        # 2. look into DHT (if given) and remove duplicates
        if dht is not None:
            existing_expert_uids = {
                found_expert.uid for found_expert in get_experts(dht, new_uids) if found_expert is not None
            }
            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]

        found_uids += new_uids

    if len(found_uids) != num_experts:
        logger.warning(
            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
            f"{attempts_per_expert * num_experts} attempts"
        )
    return found_uids