# server.py

from __future__ import annotations

import gc
import multiprocessing as mp
import random
import threading
import time
from typing import Dict, List, Optional, Union

import numpy as np
import psutil
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src import BloomConfig, declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.constants import PUBLIC_INITIAL_PEERS
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from src.dht_utils import get_remote_module_infos
from src.server import block_selection
from src.server.backend import TransformerBackend
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
from src.server.throughput import get_host_throughput
from src.utils.convert_8bit import replace_8bit_linear

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
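
# This module defines three cooperating pieces: Server, which chooses which Bloom blocks to host and restarts
# hosting when the swarm becomes imbalanced; ModuleContainer, which serves the chosen blocks; and
# ModuleAnnouncerThread, which periodically re-announces the hosted blocks to the DHT.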


class Server:
    """
    Runs ModuleContainer, periodically checks that the network is balanced,
    restarts the ModuleContainer with other layers if the imbalance is significant
    """

    def __init__(
        self,
        *,
        initial_peers: List[str],
        prefix: Optional[str],
        converted_model_name_or_path: str,
        throughput: Union[float, str],
        num_blocks: Optional[int] = None,
        block_indices: Optional[str] = None,
        num_handlers: int = 8,
        min_batch_size: int = 1,
        max_batch_size: int = 2048,
        inference_max_length: int = 2048,
        torch_dtype: str = "auto",
        revision: str = "main",
        cache_dir: Optional[str] = None,
        attn_cache_size: Optional[int] = None,
        alloc_timeout: float = 60,
        device: Optional[Union[str, torch.device]] = None,
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 30,
        expiration: Optional[float] = None,
        request_timeout: float = 3 * 60,
        session_timeout: float = 30 * 60,
        step_timeout: float = 5 * 60,
        prefetch_batches: int = 1,
        sender_threads: int = 1,
        balance_quality: float = 0.75,
        mean_balance_check_period: float = 60,
        mean_block_selection_delay: float = 0.5,
        use_auth_token: Optional[str] = None,
        load_in_8bit: bool = False,
        **kwargs,
    ):
        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
        self.converted_model_name_or_path = converted_model_name_or_path
        self.num_handlers = num_handlers
        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
        self.inference_max_length = inference_max_length
        self.cache_dir = cache_dir
        self.attn_cache_size = attn_cache_size
        self.compression = compression
        self.stats_report_interval, self.update_period = stats_report_interval, update_period
        self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
        self.use_auth_token = use_auth_token
        self.load_in_8bit = load_in_8bit

        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)

        if prefix is None:
            prefix = converted_model_name_or_path
            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                f"Please specify --prefix manually when starting a server"
            )
            logger.info(f"Automatic dht prefix: {prefix}")
        self.prefix = prefix

        if expiration is None:
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
        self.expiration = expiration

        self.request_timeout = request_timeout
        self.session_timeout, self.step_timeout = session_timeout, step_timeout

        self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
        if initial_peers == PUBLIC_INITIAL_PEERS:
            logger.info("Connecting to the public Petals swarm")
        else:
            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)

        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
        if throughput in ["auto", "eval"]:
            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
        self.throughput = throughput

        if isinstance(torch_dtype, str):
            torch_dtype = DTYPE_MAP[torch_dtype]
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
        self.torch_dtype = torch_dtype

        self.block_config = BloomConfig.from_pretrained(
            converted_model_name_or_path,
            use_auth_token=use_auth_token,
            revision=revision,
        )
        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]

        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
        if block_indices is not None:
            try:
                first_block_index, last_block_index = block_indices.split(":")
                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
            except Exception as e:
                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
                raise
            block_indices = range(first_block_index, last_block_index)
        self.strict_block_indices, self.num_blocks = block_indices, num_blocks

        self.balance_quality = balance_quality
        self.mean_balance_check_period = mean_balance_check_period
        self.mean_block_selection_delay = mean_block_selection_delay

        self.stop = threading.Event()

    def run(self):
        while True:
            block_indices = self._choose_blocks()
            self.module_container = ModuleContainer.create(
                dht=self.dht,
                prefix=self.prefix,
                converted_model_name_or_path=self.converted_model_name_or_path,
                block_config=self.block_config,
                memory_cache=self.memory_cache,
                throughput=self.throughput,
                block_indices=block_indices,
                num_handlers=self.num_handlers,
                min_batch_size=self.min_batch_size,
                max_batch_size=self.max_batch_size,
                inference_max_length=self.inference_max_length,
                torch_dtype=self.torch_dtype,
                cache_dir=self.cache_dir,
                device=self.device,
                compression=self.compression,
                stats_report_interval=self.stats_report_interval,
                update_period=self.update_period,
                expiration=self.expiration,
                request_timeout=self.request_timeout,
                session_timeout=self.session_timeout,
                step_timeout=self.step_timeout,
                prefetch_batches=self.prefetch_batches,
                sender_threads=self.sender_threads,
                use_auth_token=self.use_auth_token,
                load_in_8bit=self.load_in_8bit,
                start=True,
            )
            try:
                self.module_container.ready.wait()

                while True:
                    timeout = random.random() * 2 * self.mean_balance_check_period
                    # TODO: Follow ModuleContainer status (to restart/stop if it crashes)
                    if self.stop.wait(timeout):
                        return

                    if self._should_choose_other_blocks():
                        logger.info("Swarm is imbalanced, server will load other blocks")
                        break  # Stop serving this set of modules
            finally:
                self.module_container.shutdown()

            self._clean_memory_and_fds()

    def _clean_memory_and_fds(self):
        del self.module_container
        gc.collect()  # In particular, this closes unused file descriptors

        cur_proc = psutil.Process()
        num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")

    def _choose_blocks(self) -> List[int]:
        if self.strict_block_indices is not None:
            return self.strict_block_indices
        assert self.num_blocks is not None

        # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
        # this delay decreases the probability of a race condition while choosing the best blocks to serve.
        time.sleep(random.random() * 2 * self.mean_block_selection_delay)

        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
        return block_selection.choose_best_blocks(self.num_blocks, module_infos)

    def _should_choose_other_blocks(self) -> bool:
        if self.strict_block_indices is not None:
            return False

        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)

    def shutdown(self):
        self.stop.set()

        self.dht.shutdown()
        self.dht.join()
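

# A minimal usage sketch with illustrative argument values (run_server.py is the actual CLI entry point and
# documents all options; the model name below is hypothetical):
#
#     server = Server(
#         initial_peers=PUBLIC_INITIAL_PEERS,
#         prefix=None,  # derived from the model name when omitted
#         converted_model_name_or_path="bigscience/bloom-petals",
#         throughput="auto",
#         num_blocks=4,
#     )
#     try:
#         server.run()  # blocks, serving (and rebalancing) until server.stop is set or the process is interrupted
#     except KeyboardInterrupt:
#         pass
#     finally:
#         server.shutdown()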


class ModuleContainer(threading.Thread):
    """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
    # noinspection PyMethodOverriding
    @classmethod
    def create(
        cls,
        *,
        dht: DHT,
        prefix: str,
        converted_model_name_or_path: str,
        block_config: BloomConfig,
        memory_cache: MemoryCache,
        throughput: float,
        block_indices: List[int],
        min_batch_size: int,
        max_batch_size: int,
        torch_dtype: torch.dtype,
        cache_dir: Optional[str],
        device: Union[str, torch.device],
        compression: CompressionType,
        update_period: float,
        expiration: Optional[float],
        use_auth_token: Optional[str],
        load_in_8bit: bool,
        **kwargs,
    ) -> ModuleContainer:
        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
        joining_announcer = ModuleAnnouncerThread(
            module_uids,
            dht,
            ServerState.JOINING,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
            daemon=True,
        )
        joining_announcer.start()
        logger.info(f"Announced that blocks {block_indices} are joining")

        try:
            blocks = {}
            for module_uid, block_index in zip(module_uids, block_indices):
                block = load_pretrained_block(
                    converted_model_name_or_path,
                    block_index,
                    block_config,
                    torch_dtype=torch_dtype,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                )

                if load_in_8bit:
                    block = replace_8bit_linear(block)
                block = block.to(device)
                for param in block.parameters():
                    param.requires_grad = False

                blocks[module_uid] = TransformerBackend(
                    module_uid,
                    block,
                    memory_cache=memory_cache,
                    backend_dtype=None if torch_dtype == "auto" else torch_dtype,
                    args_schema=(
                        BatchTensorDescriptor(
                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                        ),
                    ),
                    kwargs_schema={},
                    outputs_schema=(
                        BatchTensorDescriptor(
                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                        ),
                    ),
                    min_batch_size=min_batch_size,
                    max_batch_size=max_batch_size,
                )
        except:
            joining_announcer.stop.set()
            joining_announcer.join()
            declare_active_modules(
                dht,
                module_uids,
                expiration_time=get_dht_time() + expiration,
                state=ServerState.OFFLINE,
                throughput=throughput,
            )
            logger.info(f"Announced that blocks {module_uids} are offline")
            raise
        else:
            joining_announcer.stop.set()
            joining_announcer.join()

        return cls(
            dht,
            blocks,
            throughput=throughput,
            device=device,
            update_period=update_period,
            expiration=expiration,
            **kwargs,
        )
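
    # Note: keyword arguments not consumed by __init__ below (e.g. device, prefetch_batches, sender_threads,
    # stats_report_interval as passed from Server.run) end up in **kwargs and are forwarded to hivemind's Runtime.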
    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        *,
        inference_max_length: int,
        num_handlers: int,
        throughput: float,
        update_period: float,
        expiration: Optional[float] = None,
        request_timeout: float,
        session_timeout: float,
        step_timeout: float,
        start: bool,
        **kwargs,
    ):
        super().__init__()
        self.dht, self.module_backends = dht, module_backends
        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
        self.conn_handlers = [
            TransformerConnectionHandler(
                dht,
                self.module_backends,
                inference_max_length=inference_max_length,
                request_timeout=request_timeout,
                session_timeout=session_timeout,
                step_timeout=step_timeout,
            )
            for _ in range(num_handlers)
        ]
        self.runtime = Runtime(self.module_backends, **kwargs)
        self.online_announcer = ModuleAnnouncerThread(
            list(self.module_backends.keys()),
            dht,
            ServerState.ONLINE,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
            daemon=True,
        )
        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state

        if start:
            self.run_in_background(await_ready=True)

    def run(self):
        """
        Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        self.online_announcer.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for handler in self.conn_handlers:
            handler.run_in_background()

        self.runtime.run()

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts ModuleContainer in a background thread. If await_ready, this method will wait until the container
        is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"ModuleContainer didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the container is ready to process requests.

        Example
        =======
        >>> container.start()
        >>> container.ready.wait(timeout=10)
        >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the container, process-safe.
        Please note that terminating the container otherwise (e.g. by killing processes) may result in zombie processes.
        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        self.online_announcer.stop.set()
        self.online_announcer.join()

        declare_active_modules(
            self.dht,
            self.module_backends.keys(),
            expiration_time=get_dht_time() + self.expiration,
            state=ServerState.OFFLINE,
            throughput=self.throughput,
        )
        logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")

        self.ready.clear()

        for handler in self.conn_handlers:
            handler.shutdown()
        logger.debug("Connection handlers terminated")

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        logger.debug("Shutting down pools")
        for pool in self.runtime.pools:
            if pool.is_alive():
                pool.shutdown()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()

        logger.info("Module container shut down successfully")


class ModuleAnnouncerThread(threading.Thread):
    """Periodically announces that this container hosts the specified modules, visible to all DHT peers"""

    def __init__(
        self,
        module_uids: List[str],
        dht: DHT,
        state: ServerState,
        *,
        throughput: float,
        update_period: float = 30,
        expiration: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.module_uids = module_uids
        self.dht = dht
        self.state = state
        self.throughput = throughput
        self.update_period = update_period
        self.expiration = expiration
        self.stop = threading.Event()

    def run(self) -> None:
        while True:
            declare_active_modules(
                self.dht,
                self.module_uids,
                expiration_time=get_dht_time() + self.expiration,
                state=self.state,
                throughput=self.throughput,
            )
            if self.stop.wait(self.update_period):
                break