from __future__ import annotations

import multiprocessing as mp
import random
import threading
import time
from typing import Dict, Optional, Sequence, Union

import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src import BloomConfig, declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from src.dht_utils import get_remote_module_infos
from src.server.backend import TransformerBackend
from src.server.block_selection import choose_best_blocks
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
from src.server.throughput import get_host_throughput
from src.utils.convert_8bit import replace_8bit_linear

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class Server(threading.Thread):
    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""

    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        *,
        inference_max_length: int,
        num_connection_handlers: int = 8,
        throughput: float,
        update_period: float = 30,
        expiration: Optional[float] = None,
        start: bool,
        **kwargs,
    ):
        threading.Thread.__init__(self)
        self.dht, self.module_backends = dht, module_backends
        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
        self.conn_handlers = [
            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
            for _ in range(num_connection_handlers)
        ]
        self.runtime = Runtime(self.module_backends, **kwargs)
        self.dht_handler_thread = ModuleAnnouncerThread(
            self.module_backends,
            dht,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
            daemon=True,
        )
        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state

        if start:
            self.run_in_background(await_ready=True)

    def run(self):
        """
        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Serving {len(self.module_backends)} blocks:")
        for expert_name, backend in self.module_backends.items():
            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")

        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        if self.module_backends:
            self.dht_handler_thread.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()
            process.ready.result()

        try:
            self.runtime.run()
        finally:
            self.shutdown()

    # noinspection PyMethodOverriding
    @classmethod
    def create(
        cls,
        prefix: Optional[str],
        converted_model_name_or_path: str,
        throughput: Union[float, str],
        num_blocks: Optional[int] = None,
        block_indices: Optional[str] = None,
        num_handlers: int = 8,
        min_batch_size: int = 1,
        max_batch_size: int = 4096,
        inference_max_length: Optional[int] = None,
        torch_dtype: str = "auto",
        revision: str = "main",
        cache_dir: Optional[str] = None,
        attn_cache_size: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        initial_peers: Sequence[str] = (),
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 30,
        expiration: Optional[float] = None,
        max_block_selection_delay: float = 1,
        use_auth_token: Optional[str] = None,
        load_in_8bit: bool = False,
        *,
        start: bool,
        **kwargs,
    ) -> Server:
        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)

        if prefix is None:
            prefix = converted_model_name_or_path
            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                f"Please specify --prefix manually when starting a server"
            )
            logger.info(f"Automatic dht prefix: {prefix}")
        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"

        if expiration is None:
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)

        if inference_max_length is None:
            inference_max_length = max_batch_size

        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        memory_cache = MemoryCache(device, attn_cache_size)
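        # Note: this cache is shared by every TransformerBackend created below, so the attention caches
        # of all served blocks come out of the same attn_cache_size budget.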

        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
        if throughput in ["auto", "eval"]:
            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
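        # "eval" forces a fresh throughput measurement; "auto" lets get_host_throughput reuse a previously
        # measured value if one is available (assumed behavior, inferred only from the force_eval flag above).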

        if isinstance(torch_dtype, str):
            torch_dtype = DTYPE_MAP[torch_dtype]
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"

        block_config = BloomConfig.from_pretrained(
            converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
        )

        if block_indices is not None:
            try:
                first_block_index, last_block_index = block_indices.split(":")
                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
            except Exception as e:
                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
                raise
            block_indices = range(first_block_index, last_block_index)
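            # The range is half-open: e.g. --block_indices 3:6 serves blocks 3, 4 and 5.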
        else:
            # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
            # this delay decreases the probability of a race condition while choosing the best blocks to serve.
            time.sleep(random.random() * max_block_selection_delay)

            assert num_blocks is not None
            uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
            module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
            block_indices = choose_best_blocks(num_blocks, module_infos)

        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
        declare_active_modules(
            dht,
            module_uids,
            expiration_time=get_dht_time() + expiration,
            state=ServerState.JOINING,
            throughput=throughput,
        )
        logger.info(f"Announced that blocks {block_indices} are joining")
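        # DHT lifecycle of the served blocks: JOINING is declared here, ModuleAnnouncerThread switches the
        # state to ONLINE once the server runs, and shutdown() marks the same blocks OFFLINE.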

        blocks = {}
        for module_uid, block_index in zip(module_uids, block_indices):
            block = load_pretrained_block(
                converted_model_name_or_path,
                block_index,
                block_config,
                torch_dtype=torch_dtype,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
            )

            if load_in_8bit:
                dtype = block.input_layernorm.weight.dtype
                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
                block = replace_8bit_linear(block)

            block = block.to(device)
            for param in block.parameters():
                param.requires_grad = False

            blocks[module_uid] = TransformerBackend(
                module_uid,
                block,
                memory_cache=memory_cache,
                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
                args_schema=(
                    BatchTensorDescriptor(
                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                    ),
                ),
                kwargs_schema={},
                outputs_schema=(
                    BatchTensorDescriptor(
                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                    ),
                ),
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )

        return cls(
            dht,
            blocks,
            throughput=throughput,
            num_connection_handlers=num_handlers,
            inference_max_length=inference_max_length,
            device=device,
            stats_report_interval=stats_report_interval,
            update_period=update_period,
            expiration=expiration,
            start=start,
        )

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts Server in a background thread. If await_ready, this method will wait until the background server
        is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the server is ready to process requests.

        Example
        =======
        >>> server.start()
        >>> server.ready.wait(timeout=10)
        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the server, process-safe.
        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
        If you have already caused a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        if self.module_backends:
            declare_active_modules(
                self.dht,
                self.module_backends.keys(),
                expiration_time=get_dht_time() + self.expiration,
                state=ServerState.OFFLINE,
                throughput=self.throughput,
            )
            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")

        self.ready.clear()

        for process in self.conn_handlers:
            process.terminate()
            process.join()
        logger.debug("Connection handlers terminated")

        if self.module_backends:
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        self.dht.shutdown()
        self.dht.join()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()
        logger.info("Server shut down successfully")


class ModuleAnnouncerThread(threading.Thread):
    """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""

    def __init__(
        self,
        module_backends: Dict[str, TransformerBackend],
        dht: DHT,
        *,
        throughput: float,
        update_period: float = 30,
        expiration: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.module_backends = module_backends
        self.dht = dht
        self.throughput = throughput
        self.update_period = update_period
        self.expiration = expiration
        self.stop = threading.Event()

    def run(self) -> None:
        while True:
            declare_active_modules(
                self.dht,
                self.module_backends.keys(),
                expiration_time=get_dht_time() + self.expiration,
                state=ServerState.ONLINE,
                throughput=self.throughput,
            )
            if self.stop.wait(self.update_period):
                break