server.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. from __future__ import annotations
  2. import gc
  3. import multiprocessing as mp
  4. import random
  5. import threading
  6. import time
  7. from typing import Dict, List, Optional, Union
  8. import numpy as np
  9. import psutil
  10. import torch
  11. from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
  12. from hivemind.moe.server.layers import add_custom_models_from_file
  13. from hivemind.moe.server.runtime import Runtime
  14. from hivemind.proto.runtime_pb2 import CompressionType
  15. from hivemind.utils.logging import get_logger, use_hivemind_log_handler
  16. from src import BloomConfig, declare_active_modules
  17. from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
  18. from src.constants import PUBLIC_INITIAL_PEERS
  19. from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
  20. from src.dht_utils import get_remote_module_infos
  21. from src.server import block_selection
  22. from src.server.backend import TransformerBackend
  23. from src.server.cache import MemoryCache
  24. from src.server.handler import TransformerConnectionHandler
  25. from src.server.throughput import get_host_throughput
  26. from src.utils.convert_8bit import replace_8bit_linear
# Select hivemind's "in_root_logger" log-handler mode before any logger is created.
# NOTE(review): presumably this attaches hivemind's handler to the root logger - confirm against hivemind docs.
use_hivemind_log_handler("in_root_logger")
# Module-level logger used by every class in this file.
logger = get_logger(__file__)
  29. class Server:
  30. """
  31. Runs ModuleContainer, periodically checks that the network is balanced,
  32. restarts the ModuleContainer with other layers if the imbalance is significant
  33. """
  34. def __init__(
  35. self,
  36. *,
  37. initial_peers: List[str],
  38. prefix: Optional[str],
  39. converted_model_name_or_path: str,
  40. throughput: Union[float, str],
  41. num_blocks: Optional[int] = None,
  42. block_indices: Optional[str] = None,
  43. num_handlers: int = 8,
  44. min_batch_size: int = 1,
  45. max_batch_size: int = 4096,
  46. inference_max_length: int = 4096,
  47. torch_dtype: str = "auto",
  48. revision: str = "main",
  49. cache_dir: Optional[str] = None,
  50. attn_cache_size: Optional[int] = None,
  51. device: Optional[Union[str, torch.device]] = None,
  52. compression=CompressionType.NONE,
  53. stats_report_interval: Optional[int] = None,
  54. custom_module_path=None,
  55. update_period: float = 30,
  56. expiration: Optional[float] = None,
  57. prefetch_batches: int = 1,
  58. sender_threads: int = 1,
  59. balance_quality: float = 0.75,
  60. mean_balance_check_period: float = 60,
  61. mean_block_selection_delay: float = 0.5,
  62. use_auth_token: Optional[str] = None,
  63. load_in_8bit: bool = False,
  64. **kwargs,
  65. ):
  66. """Create a server with one or more bloom blocks. See run_server.py for documentation."""
  67. self.converted_model_name_or_path = converted_model_name_or_path
  68. self.num_handlers = num_handlers
  69. self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
  70. self.inference_max_length = inference_max_length
  71. self.cache_dir = cache_dir
  72. self.attn_cache_size = attn_cache_size
  73. self.compression = compression
  74. self.stats_report_interval, self.update_period = stats_report_interval, update_period
  75. self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
  76. self.use_auth_token = use_auth_token
  77. self.load_in_8bit = load_in_8bit
  78. if custom_module_path is not None:
  79. add_custom_models_from_file(custom_module_path)
  80. if prefix is None:
  81. prefix = converted_model_name_or_path
  82. assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
  83. f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
  84. f"Please specify --prefix manually when starting a server"
  85. )
  86. logger.info(f"Automatic dht prefix: {prefix}")
  87. self.prefix = prefix
  88. if expiration is None:
  89. expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
  90. self.expiration = expiration
  91. self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
  92. visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
  93. if initial_peers == PUBLIC_INITIAL_PEERS:
  94. logger.info("Connecting to the public Petals swarm")
  95. else:
  96. logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
  97. device = device or ("cuda" if torch.cuda.is_available() else "cpu")
  98. self.device = device
  99. self.memory_cache = MemoryCache(device, attn_cache_size)
  100. assert isinstance(throughput, float) or throughput in ["auto", "eval"]
  101. if throughput in ["auto", "eval"]:
  102. throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
  103. self.throughput = throughput
  104. if isinstance(torch_dtype, str):
  105. torch_dtype = DTYPE_MAP[torch_dtype]
  106. assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
  107. self.torch_dtype = torch_dtype
  108. self.block_config = BloomConfig.from_pretrained(
  109. converted_model_name_or_path,
  110. use_auth_token=use_auth_token,
  111. revision=revision,
  112. )
  113. self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
  114. assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
  115. if block_indices is not None:
  116. try:
  117. first_block_index, last_block_index = block_indices.split(":")
  118. first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
  119. except Exception as e:
  120. logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
  121. raise
  122. block_indices = range(first_block_index, last_block_index)
  123. self.strict_block_indices, self.num_blocks = block_indices, num_blocks
  124. self.balance_quality = balance_quality
  125. self.mean_balance_check_period = mean_balance_check_period
  126. self.mean_block_selection_delay = mean_block_selection_delay
  127. self.stop = threading.Event()
  128. def run(self):
  129. while True:
  130. block_indices = self._choose_blocks()
  131. self.module_container = ModuleContainer.create(
  132. dht=self.dht,
  133. prefix=self.prefix,
  134. converted_model_name_or_path=self.converted_model_name_or_path,
  135. block_config=self.block_config,
  136. memory_cache=self.memory_cache,
  137. throughput=self.throughput,
  138. block_indices=block_indices,
  139. num_handlers=self.num_handlers,
  140. min_batch_size=self.min_batch_size,
  141. max_batch_size=self.max_batch_size,
  142. inference_max_length=self.inference_max_length,
  143. torch_dtype=self.torch_dtype,
  144. cache_dir=self.cache_dir,
  145. device=self.device,
  146. compression=self.compression,
  147. stats_report_interval=self.stats_report_interval,
  148. update_period=self.update_period,
  149. expiration=self.expiration,
  150. prefetch_batches=self.prefetch_batches,
  151. sender_threads=self.sender_threads,
  152. use_auth_token=self.use_auth_token,
  153. load_in_8bit=self.load_in_8bit,
  154. start=True,
  155. )
  156. try:
  157. self.module_container.ready.wait()
  158. while True:
  159. timeout = random.random() * 2 * self.mean_balance_check_period
  160. # TODO: Follow ModuleContainer status (to restart/stop if it crashes)
  161. if self.stop.wait(timeout):
  162. return
  163. if self._should_choose_other_blocks():
  164. logger.info("Swarm is imbalanced, server will load other blocks")
  165. break # Stop serving this set of modules
  166. finally:
  167. self.module_container.shutdown()
  168. self._clean_memory_and_fds()
  169. def _clean_memory_and_fds(self):
  170. del self.module_container
  171. gc.collect() # In particular, this closes unused file descriptors
  172. cur_proc = psutil.Process()
  173. num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
  174. logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
  175. def _choose_blocks(self) -> List[int]:
  176. if self.strict_block_indices is not None:
  177. return self.strict_block_indices
  178. assert self.num_blocks is not None
  179. # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
  180. # this delay decreases the probability of a race condition while choosing the best blocks to serve.
  181. time.sleep(random.random() * 2 * self.mean_block_selection_delay)
  182. module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
  183. return block_selection.choose_best_blocks(self.num_blocks, module_infos)
  184. def _should_choose_other_blocks(self) -> bool:
  185. if self.strict_block_indices is not None:
  186. return False
  187. module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
  188. return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
  189. def shutdown(self):
  190. self.stop.set()
  191. self.dht.shutdown()
  192. self.dht.join()
  193. class ModuleContainer(threading.Thread):
  194. """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
  195. # noinspection PyMethodOverriding
  196. @classmethod
  197. def create(
  198. cls,
  199. *,
  200. dht: DHT,
  201. prefix: str,
  202. converted_model_name_or_path: str,
  203. block_config: BloomConfig,
  204. memory_cache: MemoryCache,
  205. throughput: float,
  206. block_indices: List[int],
  207. num_handlers: Optional[int],
  208. min_batch_size: int,
  209. max_batch_size: int,
  210. inference_max_length: int,
  211. torch_dtype: torch.dtype,
  212. cache_dir: Optional[str],
  213. device: Union[str, torch.device],
  214. compression: CompressionType,
  215. stats_report_interval: Optional[int],
  216. update_period: float,
  217. expiration: Optional[float],
  218. prefetch_batches: int,
  219. sender_threads: int,
  220. use_auth_token: Optional[str],
  221. load_in_8bit: bool,
  222. start: bool,
  223. ) -> ModuleContainer:
  224. module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
  225. joining_announcer = ModuleAnnouncerThread(
  226. module_uids,
  227. dht,
  228. ServerState.JOINING,
  229. throughput=throughput,
  230. update_period=update_period,
  231. expiration=expiration,
  232. daemon=True,
  233. )
  234. joining_announcer.start()
  235. logger.info(f"Announced that blocks {block_indices} are joining")
  236. try:
  237. blocks = {}
  238. for module_uid, block_index in zip(module_uids, block_indices):
  239. block = load_pretrained_block(
  240. converted_model_name_or_path,
  241. block_index,
  242. block_config,
  243. torch_dtype=torch_dtype,
  244. use_auth_token=use_auth_token,
  245. cache_dir=cache_dir,
  246. )
  247. if load_in_8bit:
  248. dtype = block.input_layernorm.weight.dtype
  249. block = replace_8bit_linear(block)
  250. block = block.to(device)
  251. for param in block.parameters():
  252. param.requires_grad = False
  253. blocks[module_uid] = TransformerBackend(
  254. module_uid,
  255. block,
  256. memory_cache=memory_cache,
  257. backend_dtype=None if torch_dtype == "auto" else torch_dtype,
  258. args_schema=(
  259. BatchTensorDescriptor(
  260. 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
  261. ),
  262. ),
  263. kwargs_schema={},
  264. outputs_schema=(
  265. BatchTensorDescriptor(
  266. 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
  267. ),
  268. ),
  269. min_batch_size=min_batch_size,
  270. max_batch_size=max_batch_size,
  271. )
  272. except:
  273. joining_announcer.stop.set()
  274. joining_announcer.join()
  275. declare_active_modules(
  276. dht,
  277. module_uids,
  278. expiration_time=get_dht_time() + expiration,
  279. state=ServerState.OFFLINE,
  280. throughput=throughput,
  281. )
  282. logger.info(f"Announced that blocks {module_uids} are offline")
  283. raise
  284. else:
  285. joining_announcer.stop.set()
  286. joining_announcer.join()
  287. return cls(
  288. dht,
  289. blocks,
  290. throughput=throughput,
  291. num_connection_handlers=num_handlers,
  292. inference_max_length=inference_max_length,
  293. device=device,
  294. stats_report_interval=stats_report_interval,
  295. update_period=update_period,
  296. expiration=expiration,
  297. prefetch_batches=prefetch_batches,
  298. sender_threads=sender_threads,
  299. start=start,
  300. )
  301. def __init__(
  302. self,
  303. dht: DHT,
  304. module_backends: Dict[str, TransformerBackend],
  305. *,
  306. inference_max_length: int,
  307. num_connection_handlers: int,
  308. throughput: float,
  309. update_period: float,
  310. expiration: Optional[float] = None,
  311. start: bool,
  312. **kwargs,
  313. ):
  314. super().__init__()
  315. self.dht, self.module_backends = dht, module_backends
  316. self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
  317. self.conn_handlers = [
  318. TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
  319. for _ in range(num_connection_handlers)
  320. ]
  321. self.runtime = Runtime(self.module_backends, **kwargs)
  322. self.online_announcer = ModuleAnnouncerThread(
  323. list(self.module_backends.keys()),
  324. dht,
  325. ServerState.ONLINE,
  326. throughput=throughput,
  327. update_period=update_period,
  328. expiration=expiration,
  329. daemon=True,
  330. )
  331. self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
  332. if start:
  333. self.run_in_background(await_ready=True)
  334. def run(self):
  335. """
  336. Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
  337. runs Runtime (self.runtime) to process incoming requests.
  338. """
  339. logger.info(f"Serving {len(self.module_backends)} blocks:")
  340. for expert_name, backend in self.module_backends.items():
  341. num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
  342. logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
  343. if not self.dht.is_alive():
  344. self.dht.run_in_background(await_ready=True)
  345. self.online_announcer.start()
  346. if self.checkpoint_saver is not None:
  347. self.checkpoint_saver.start()
  348. for handler in self.conn_handlers:
  349. handler.run_in_background()
  350. self.runtime.run()
  351. def run_in_background(self, await_ready=True, timeout=None):
  352. """
  353. Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
  354. is ready to process incoming requests or for :timeout: seconds max.
  355. """
  356. self.start()
  357. if await_ready and not self.ready.wait(timeout=timeout):
  358. raise TimeoutError("ModuleContainer didn't notify .ready in {timeout} seconds")
  359. @property
  360. def ready(self) -> mp.synchronize.Event:
  361. """
  362. An event (multiprocessing.Event) that is set when the container is ready to process requests.
  363. Example
  364. =======
  365. >>> container.start()
  366. >>> container.ready.wait(timeout=10)
  367. >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds")
  368. """
  369. return self.runtime.ready # mp.Event that is true if self is ready to process batches
  370. def shutdown(self):
  371. """
  372. Gracefully terminate the container, process-safe.
  373. Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
  374. If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
  375. """
  376. self.online_announcer.stop.set()
  377. self.online_announcer.join()
  378. declare_active_modules(
  379. self.dht,
  380. self.module_backends.keys(),
  381. expiration_time=get_dht_time() + self.expiration,
  382. state=ServerState.OFFLINE,
  383. throughput=self.throughput,
  384. )
  385. logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
  386. self.ready.clear()
  387. for handler in self.conn_handlers:
  388. handler.shutdown()
  389. logger.debug("Connection handlers terminated")
  390. if self.checkpoint_saver is not None:
  391. self.checkpoint_saver.stop.set()
  392. self.checkpoint_saver.join()
  393. logger.debug(f"Shutting down pools")
  394. for pool in self.runtime.pools:
  395. if pool.is_alive():
  396. pool.shutdown()
  397. logger.debug(f"Shutting down runtime")
  398. self.runtime.shutdown()
  399. logger.info("Module container shut down succesfully")
  400. class ModuleAnnouncerThread(threading.Thread):
  401. """Periodically announces that this container hosts the specified modules, visible to all DHT peers"""
  402. def __init__(
  403. self,
  404. module_uids: List[str],
  405. dht: DHT,
  406. state: ServerState,
  407. *,
  408. throughput: float,
  409. update_period: float = 30,
  410. expiration: float,
  411. **kwargs,
  412. ):
  413. super().__init__(**kwargs)
  414. self.module_uids = module_uids
  415. self.dht = dht
  416. self.state = state
  417. self.throughput = throughput
  418. self.update_period = update_period
  419. self.expiration = expiration
  420. self.stop = threading.Event()
  421. def run(self) -> None:
  422. while True:
  423. declare_active_modules(
  424. self.dht,
  425. self.module_uids,
  426. expiration_time=get_dht_time() + self.expiration,
  427. state=self.state,
  428. throughput=self.throughput,
  429. )
  430. if self.stop.wait(self.update_period):
  431. break