server.py

from __future__ import annotations

import multiprocessing as mp
import random
import threading
import time
from typing import Dict, Optional, Sequence, Union

import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src import BloomConfig, declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from src.dht_utils import get_remote_module_infos
from src.server.backend import TransformerBackend
from src.server.block_selection import choose_best_blocks
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
from src.server.throughput import get_host_throughput

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class Server(threading.Thread):
    """Serves one or more bloom layers for inference, forward, and backward; announces itself to the DHT"""

    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        *,
        device: torch.device,
        num_connection_handlers: int = 8,
        throughput: float,
        update_period: float = 30,
        expiration: Optional[float] = None,
        start: bool,
        **kwargs,
    ):
        threading.Thread.__init__(self)
        self.dht, self.module_backends = dht, module_backends
        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
        self.conn_handlers = [
            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
        ]
        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
        self.dht_handler_thread = ModuleAnnouncerThread(
            self.module_backends,
            dht,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
            daemon=True,
        )
        self.checkpoint_saver = None  # no need to save checkpoints since we do not change the model state

        if start:
            self.run_in_background(await_ready=True)

    def run(self):
        """
        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Serving {len(self.module_backends)} blocks:")
        for expert_name, backend in self.module_backends.items():
            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")

        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        if self.module_backends:
            self.dht_handler_thread.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()
            process.ready.result()

        try:
            self.runtime.run()
        finally:
            self.shutdown()

    # noinspection PyMethodOverriding
    @classmethod
    def create(
        cls,
        prefix: Optional[str],
        converted_model_name_or_path: str,
        throughput: Union[float, str],
        num_blocks: Optional[int] = None,
        block_indices: Optional[str] = None,
        num_handlers: Optional[int] = None,
        min_batch_size: int = 1,
        max_batch_size: int = 4096,
        torch_dtype: str = "auto",
        cache_size_bytes: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        initial_peers: Sequence[str] = (),
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 30,
        expiration: Optional[float] = None,
        max_block_selection_delay: float = 1,
        use_auth_token: Optional[str] = None,
        *,
        start: bool,
        **kwargs,
    ) -> Server:
        """Create a server with one or more bloom blocks. See run_server.py for documentation; a minimal usage sketch appears at the end of this file."""
        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)

        if prefix is None:
            prefix = converted_model_name_or_path
            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                f"please specify --prefix manually when starting a server"
            )
            logger.info(f"Automatic dht prefix: {prefix}")
        assert (block_indices is None) != (num_blocks is None), "Please specify either num_blocks or block_indices (not both)"

        if expiration is None:
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)

        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        memory_cache = MemoryCache(device, cache_size_bytes)

        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
        if throughput in ["auto", "eval"]:
            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))

        if isinstance(torch_dtype, str):
            torch_dtype = DTYPE_MAP[torch_dtype]
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"

        block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)

        if block_indices is not None:
            try:
                first_block_index, last_block_index = block_indices.split(":")
                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
            except Exception as e:
                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
                raise
            block_indices = range(first_block_index, last_block_index)
        else:
            # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
            # this delay decreases the probability of a race condition while choosing the best blocks to serve.
            time.sleep(random.random() * max_block_selection_delay)

            assert num_blocks is not None
            uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
            module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
            block_indices = choose_best_blocks(num_blocks, module_infos)

        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
        declare_active_modules(
            dht,
            module_uids,
            expiration_time=get_dht_time() + expiration,
            state=ServerState.JOINING,
            throughput=throughput,
        )
        logger.info(f"Announced that blocks {block_indices} are joining")

        blocks = {}
        for module_uid, block_index in zip(module_uids, block_indices):
            block = load_pretrained_block(
                converted_model_name_or_path,
                block_index,
                block_config,
                torch_dtype=torch_dtype,
                use_auth_token=use_auth_token,
            )
            for param in block.parameters():
                param.requires_grad = False  # weights are frozen: backward only needs gradients w.r.t. activations
            blocks[module_uid] = TransformerBackend(
                module_uid,
                block,
                memory_cache=memory_cache,
                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
                args_schema=(
                    BatchTensorDescriptor(
                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                    ),
                ),
                kwargs_schema={},
                outputs_schema=(
                    BatchTensorDescriptor(
                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                    ),
                ),
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )

        return cls(
            dht,
            blocks,
            throughput=throughput,
            # fall back to __init__'s default of 8 handlers: passing None verbatim would crash range()
            num_connection_handlers=num_handlers if num_handlers is not None else 8,
            device=device,
            stats_report_interval=stats_report_interval,
            update_period=update_period,
            expiration=expiration,
            start=start,
        )

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts Server in a background thread. If await_ready, this method will wait until the background server
        is ready to process incoming requests, or for :timeout: seconds at most.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the server is ready to process requests.

        Example
        =======
        >>> server.start()
        >>> server.ready.wait(timeout=10)
        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is set when self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the server, process-safe.
        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
        If you have already caused a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        if self.module_backends:
            declare_active_modules(
                self.dht,
                self.module_backends.keys(),
                expiration_time=get_dht_time() + self.expiration,
                state=ServerState.OFFLINE,
                throughput=self.throughput,
            )
            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")

        self.ready.clear()

        for process in self.conn_handlers:
            process.terminate()
            process.join()
        logger.debug("Connection handlers terminated")

        if self.module_backends:
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        self.dht.shutdown()
        self.dht.join()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()
        logger.info("Server shut down successfully")


class ModuleAnnouncerThread(threading.Thread):
    """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""

    def __init__(
        self,
        module_backends: Dict[str, TransformerBackend],
        dht: DHT,
        *,
        throughput: float,
        update_period: float = 30,
        expiration: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.module_backends = module_backends
        self.dht = dht
        self.throughput = throughput
        self.update_period = update_period
        self.expiration = expiration
        self.stop = threading.Event()

    def run(self) -> None:
        while True:
            declare_active_modules(
                self.dht,
                self.module_backends.keys(),
                expiration_time=get_dht_time() + self.expiration,
                state=ServerState.ONLINE,
                throughput=self.throughput,
            )
            if self.stop.wait(self.update_period):
                break
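

# --- Usage sketch (editor's addition, not part of the original module) ---
# A minimal, hypothetical way to launch a server via Server.create(). In practice this
# module is driven by run_server.py's command-line arguments; the model name, peer
# multiaddr, and block range below are placeholder assumptions, not real values.
if __name__ == "__main__":
    server = Server.create(
        prefix=None,  # derived automatically from the model name when omitted
        converted_model_name_or_path="bigscience/test-bloomd",  # hypothetical converted checkpoint
        throughput="auto",  # measure host throughput on first start
        block_indices="0:4",  # serve blocks 0..3; alternatively, pass num_blocks=4
        initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmPlaceholderPeerID"],  # placeholder multiaddr
        start=True,  # launch the background thread immediately
    )
    try:
        server.join()  # keep the main thread alive while the serving thread runs
    except KeyboardInterrupt:
        server.shutdown()  # announce OFFLINE, stop handlers, DHT, and runtime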