server.py

from __future__ import annotations

import gc
import math
import multiprocessing as mp
import random
import threading
import time
from typing import Dict, List, Optional, Sequence, Union

import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger
from transformers import BloomConfig

from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from petals.dht_utils import declare_active_modules, get_remote_module_infos
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size
from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.convert_block import check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR

logger = get_logger(__name__)


class Server:
    """
    Runs ModuleContainer, periodically checks that the network is balanced,
    restarts the ModuleContainer with other layers if the imbalance is significant
    """

    def __init__(
        self,
        *,
        initial_peers: List[str],
        prefix: Optional[str],
        converted_model_name_or_path: str,
        throughput: Union[float, str],
        num_blocks: Optional[int] = None,
        block_indices: Optional[str] = None,
        num_handlers: int = 8,
        min_batch_size: int = 1,
        max_batch_size: int = 2048,
        inference_max_length: int = 2048,
        torch_dtype: str = "auto",
        revision: str = "main",
        cache_dir: Optional[str] = None,
        max_disk_space: Optional[int] = None,
        attn_cache_size: Optional[int] = None,
        alloc_timeout: float = 60,
        device: Optional[Union[str, torch.device]] = None,
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 150,
        expiration: Optional[float] = None,
        request_timeout: float = 3 * 60,
        session_timeout: float = 30 * 60,
        step_timeout: float = 5 * 60,
        prefetch_batches: int = 1,
        sender_threads: int = 1,
        balance_quality: float = 0.75,
        mean_balance_check_period: float = 120,
        mean_block_selection_delay: float = 2.5,
        use_auth_token: Optional[str] = None,
        load_in_8bit: Optional[bool] = None,
        tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
        skip_reachability_check: bool = False,
        dht_client_mode: Optional[bool] = None,
        use_relay: bool = True,
        use_auto_relay: bool = True,
        **kwargs,
    ):
  77. """Create a server with one or more bloom blocks. See run_server.py for documentation."""
  78. self.converted_model_name_or_path = converted_model_name_or_path
  79. self.num_handlers = num_handlers
  80. self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
  81. self.inference_max_length = inference_max_length
  82. self.compression = compression
  83. self.stats_report_interval, self.update_period = stats_report_interval, update_period
  84. self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
  85. self.use_auth_token = use_auth_token
  86. if custom_module_path is not None:
  87. add_custom_models_from_file(custom_module_path)
  88. if prefix is None:
  89. prefix = converted_model_name_or_path
  90. assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
  91. f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
  92. f"Please specify --prefix manually when starting a server"
  93. )
  94. logger.debug(f"Automatic dht prefix: {prefix}")
  95. self.prefix = prefix
  96. if expiration is None:
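            # Default (per the formula below): entries outlive two announcement periods, but never
            # less than hivemind's maximum allowed clock discrepancy between peers.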
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
        self.expiration = expiration

        self.request_timeout = request_timeout
        self.session_timeout, self.step_timeout = session_timeout, step_timeout

        self.block_config = BloomConfig.from_pretrained(
            converted_model_name_or_path,
            use_auth_token=use_auth_token,
            revision=revision,
        )
        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]

        if dht_client_mode is None:
            is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
            dht_client_mode = is_reachable is False  # if could not check reachability (returns None), run a full peer
            logger.info(f"This server is accessible {'via relays' if dht_client_mode else 'directly'}")
        self.dht = DHT(
            initial_peers=initial_peers,
            start=True,
            num_workers=self.block_config.n_layer,
            use_relay=use_relay,
            use_auto_relay=use_auto_relay,
            client_mode=dht_client_mode,
            **kwargs,
        )
        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None

        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
        if initial_peers == PUBLIC_INITIAL_PEERS:
            logger.info("Connecting to the public swarm")
        else:
            logger.info(f"Connecting to a private swarm, initial peers: {initial_peers}")
        logger.info(f"Running a server on {visible_maddrs_str}")
        self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(device)
        if device.type == "cuda" and device.index is None:
            device = torch.device(device.type, index=0)
        self.device = device

        if isinstance(torch_dtype, str):
            torch_dtype = DTYPE_MAP[torch_dtype]
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
        self.torch_dtype = torch_dtype

        if tensor_parallel_devices is None:
            tensor_parallel_devices = (device,)
        self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
        if len(self.tensor_parallel_devices) > 1:
            # str.join() needs strings, so convert torch.device objects explicitly
            logger.info(f"Model weights will be split between {', '.join(map(str, self.tensor_parallel_devices))}")
            check_device_balance(self.tensor_parallel_devices)

        if load_in_8bit is None:
            load_in_8bit = device.type == "cuda"
        self.load_in_8bit = load_in_8bit
        logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")

        assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
        if num_blocks is None and block_indices is None:
            num_blocks = self._choose_num_blocks()
        if block_indices is not None:
            try:
                first_block_index, last_block_index = block_indices.split(":")
                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
            except Exception as e:
                raise ValueError(
                    f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)"
                ) from e
            block_indices = range(first_block_index, last_block_index)
            num_blocks = len(block_indices)
        self.strict_block_indices, self.num_blocks = block_indices, num_blocks

        gib = 1024**3
        if attn_cache_size is None:
            # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
            attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
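            # E.g., for bigscience/bloom-petals (hidden_size == 14336) with num_blocks == 20,
            # this reserves 0.5 GiB * 20 = 10 GiB of attention cache.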
        self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
        logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")

        if cache_dir is None:
            cache_dir = DEFAULT_CACHE_DIR
        self.cache_dir = cache_dir
        self.max_disk_space = max_disk_space

        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
        if throughput in ["auto", "eval"]:
            throughput = get_server_throughput(
                self.block_config,
                device,
                torch_dtype,
                num_blocks=num_blocks,
                load_in_8bit=load_in_8bit,
                tensor_parallel_devices=self.tensor_parallel_devices,
                force_eval=(throughput == "eval"),
                cache_dir=cache_dir,
            )
        self.throughput = throughput

        self.balance_quality = balance_quality
        self.mean_balance_check_period = mean_balance_check_period
        self.mean_block_selection_delay = mean_block_selection_delay

        self.stop = threading.Event()

    def _choose_num_blocks(self) -> int:
        assert self.device.type == "cuda", (
            "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
            "CPU-only servers in the public swarm are discouraged since they are much slower"
        )
        num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1

        if num_devices > 1:
            memory_per_device = tuple(
                torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
            )
            total_memory = min(memory_per_device) * num_devices
            if max(memory_per_device) / min(memory_per_device) > 1.5:
                raise ValueError(
                    "GPU devices have highly uneven memory, which makes tensor parallelism inefficient. "
                    "Please launch individual servers on each GPU or set --num_blocks manually to "
                    "override this exception."
                )
        else:
            total_memory = torch.cuda.get_device_properties(self.device).total_memory

        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)

        # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models
        gib = 1024**3
        attn_cache_per_block = 0.5 * gib * num_devices  # TODO: This does not account for manually set --attn_cache_size
        autograd_memory = 2 * gib * num_devices  # GPU memory used for intermediate tensors in rpc_backward

        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
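        # Worked example with hypothetical numbers: one 24 GiB GPU and 3 GiB blocks give
        # floor((24 - 2) / (3 + 0.5)) = 6 blocks.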
        assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"

        logger.info(
            f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
            f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
        )
        return min(num_blocks, self.block_config.n_layer)

    def run(self):
        while True:
            block_indices = self._choose_blocks()
            self.module_container = ModuleContainer.create(
                dht=self.dht,
                prefix=self.prefix,
                converted_model_name_or_path=self.converted_model_name_or_path,
                block_config=self.block_config,
                attn_cache_size=self.attn_cache_size,
                alloc_timeout=self.alloc_timeout,
                throughput=self.throughput,
                block_indices=block_indices,
                num_handlers=self.num_handlers,
                min_batch_size=self.min_batch_size,
                max_batch_size=self.max_batch_size,
                inference_max_length=self.inference_max_length,
                torch_dtype=self.torch_dtype,
                cache_dir=self.cache_dir,
                max_disk_space=self.max_disk_space,
                device=self.device,
                compression=self.compression,
                stats_report_interval=self.stats_report_interval,
                update_period=self.update_period,
                expiration=self.expiration,
                request_timeout=self.request_timeout,
                session_timeout=self.session_timeout,
                step_timeout=self.step_timeout,
                prefetch_batches=self.prefetch_batches,
                sender_threads=self.sender_threads,
                use_auth_token=self.use_auth_token,
                load_in_8bit=self.load_in_8bit,
                tensor_parallel_devices=self.tensor_parallel_devices,
                should_validate_reachability=self.should_validate_reachability,
                start=True,
            )
            try:
                self.module_container.ready.wait()

                while True:
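                    # The timeout is drawn uniformly from [0, 2 * mean_balance_check_period],
                    # so checks happen once per period on average while staying desynchronized
                    # across servers that started together.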
                    timeout = random.random() * 2 * self.mean_balance_check_period
                    if self.stop.wait(timeout):
                        return

                    if not self.module_container.is_healthy():
                        logger.warning("One of subprocesses crashed, restarting the server")
                        break

                    if self._should_choose_other_blocks():
                        logger.info("Swarm is imbalanced, server will load other blocks")
                        break  # Stop serving this set of modules
            finally:
                self.module_container.shutdown()

            self._clean_memory_and_fds()

    def _clean_memory_and_fds(self):
        del self.module_container
        gc.collect()  # In particular, this closes unused file descriptors

        if self.device.type == "cuda":
            torch.cuda.empty_cache()

            allocated_vram = torch.cuda.memory_allocated(self.device)
            reserved_vram = torch.cuda.memory_reserved(self.device)
            gib = 1024**3
            logger.info(
                f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                f"{reserved_vram / gib:.1f} GiB reserved memory"
            )

    def _choose_blocks(self) -> List[int]:
        if self.strict_block_indices is not None:
            return self.strict_block_indices

        # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
        # this delay decreases the probability of a race condition while choosing the best blocks to serve.
        time.sleep(random.random() * 2 * self.mean_block_selection_delay)
        module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
        return block_selection.choose_best_blocks(self.num_blocks, module_infos)

    def _should_choose_other_blocks(self) -> bool:
        if self.strict_block_indices is not None:
            return False

        module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)

    def shutdown(self):
        self.stop.set()

        if self.reachability_protocol is not None:
            self.reachability_protocol.shutdown()
        self.dht.shutdown()
        self.dht.join()


class ModuleContainer(threading.Thread):
    """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""

    # noinspection PyMethodOverriding
    @classmethod
    def create(
        cls,
        *,
        dht: DHT,
        prefix: str,
        converted_model_name_or_path: str,
        block_config: BloomConfig,
        attn_cache_size: int,
        alloc_timeout: float,
        throughput: float,
        block_indices: List[int],
        min_batch_size: int,
        max_batch_size: int,
        torch_dtype: torch.dtype,
        cache_dir: str,
        max_disk_space: int,
        device: Union[str, torch.device],
        compression: CompressionType,
        update_period: float,
        expiration: Optional[float],
        use_auth_token: Optional[str],
        load_in_8bit: bool,
        tensor_parallel_devices: Sequence[torch.device],
        should_validate_reachability: bool,
        **kwargs,
    ) -> ModuleContainer:
        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
        joining_announcer = ModuleAnnouncerThread(
            module_uids,
            dht,
            ServerState.JOINING,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
            daemon=True,
        )
        joining_announcer.start()
        logger.info(f"Announced that blocks {block_indices} are joining")

        assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)

        memory_cache = MemoryCache(attn_cache_size, alloc_timeout)
        blocks = {}
        try:
            for module_uid, block_index in zip(module_uids, block_indices):
                block = load_pretrained_block(
                    converted_model_name_or_path,
                    block_index,
                    block_config,
                    torch_dtype=torch_dtype,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    max_disk_space=max_disk_space,
                )
                block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
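                # Note: torch_dtype may still be the literal string "auto" if create() is
                # called directly; in that case, fall back to the dtype of the loaded weights.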
                backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
                blocks[module_uid] = TransformerBackend(
                    module_uid,
                    block,
                    config=block_config,
                    memory_cache=memory_cache,
                    backend_dtype=backend_dtype,
                    args_schema=(
                        BatchTensorDescriptor(
                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
                        ),
                    ),
                    kwargs_schema={},
                    outputs_schema=(
                        BatchTensorDescriptor(
                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
                        ),
                    ),
                    min_batch_size=min_batch_size,
                    max_batch_size=max_batch_size,
                )

            if should_validate_reachability:
                validate_reachability(dht.peer_id)
        except:
            logger.debug("Shutting down backends")
            for backend in blocks.values():
                backend.shutdown()

            joining_announcer.stop.set()
            joining_announcer.join()
            declare_active_modules(
                dht,
                module_uids,
                expiration_time=get_dht_time() + expiration,
                state=ServerState.OFFLINE,
                throughput=throughput,
            )
            logger.info(f"Announced that blocks {module_uids} are offline")
            raise
        else:
            joining_announcer.stop.set()
            joining_announcer.join()

        merge_inference_pools_inplace(blocks)

        return cls(
            dht,
            blocks,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
            **kwargs,
        )

    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        *,
        inference_max_length: int,
        num_handlers: int,
        throughput: float,
        update_period: float,
        expiration: Optional[float] = None,
        request_timeout: float,
        session_timeout: float,
        step_timeout: float,
        start: bool,
        **kwargs,
    ):
        super().__init__()

        self.dht, self.module_backends = dht, module_backends
        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
        self.conn_handlers = [
            TransformerConnectionHandler(
                dht,
                self.module_backends,
                inference_max_length=inference_max_length,
                request_timeout=request_timeout,
                session_timeout=session_timeout,
                step_timeout=step_timeout,
            )
            for _ in range(num_handlers)
        ]
        # Note: we set device=None in the runtime to avoid moving all modules to device 0 in runtime.run().
        # tensor_parallel has already moved each block to its target device(s) as needed.
        self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
        self.online_announcer = ModuleAnnouncerThread(
            list(self.module_backends.keys()),
            dht,
            ServerState.ONLINE,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
            daemon=True,
        )
        self.checkpoint_saver = None  # Checkpoint saving is not enabled here; shutdown() checks this attribute

        if start:
            self.run_in_background(await_ready=True)

    def run(self):
        """
        Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        self.online_announcer.start()

        for handler in self.conn_handlers:
            handler.run_in_background()

        self.runtime.run()

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts ModuleContainer in a background thread. If await_ready, this method will wait until the container
        is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"ModuleContainer didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the container is ready to process requests.

        Example
        =======
        >>> container.start()
        >>> container.ready.wait(timeout=10)
        >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def is_healthy(self) -> bool:
        return all(handler.is_alive() for handler in self.conn_handlers) and all(
            pool.is_alive() for pool in self.runtime.pools
        )

    def shutdown(self):
        """
        Gracefully terminate the container, process-safe.
        Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        self.online_announcer.stop.set()
        self.online_announcer.join()

        declare_active_modules(
            self.dht,
            self.module_backends.keys(),
            expiration_time=get_dht_time() + self.expiration,
            state=ServerState.OFFLINE,
            throughput=self.throughput,
        )
        logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")

        self.ready.clear()

        for handler in self.conn_handlers:
            handler.shutdown()
        logger.debug("Connection handlers terminated")

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        logger.debug("Shutting down pools")
        for pool in self.runtime.pools:
            if pool.is_alive():
                pool.shutdown()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()

        logger.debug("Shutting down backends")
        for backend in self.module_backends.values():
            backend.shutdown()

        logger.info("Module container shut down successfully")


class ModuleAnnouncerThread(threading.Thread):
    """Periodically announces that this container hosts the specified modules, visible to all DHT peers"""

    def __init__(
        self,
        module_uids: List[str],
        dht: DHT,
        state: ServerState,
        *,
        throughput: float,
        update_period: float = 30,
        expiration: float,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.module_uids = module_uids
        self.dht = dht
        self.state = state
        self.throughput = throughput
        self.update_period = update_period
        self.expiration = expiration
        self.stop = threading.Event()

    def run(self) -> None:
        while True:
            declare_active_modules(
                self.dht,
                self.module_uids,
                expiration_time=get_dht_time() + self.expiration,
                state=self.state,
                throughput=self.throughput,
            )
            if self.stop.wait(self.update_period):
                break


class RuntimeWithDeduplicatedPools(Runtime):
    """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pools = tuple(set(self.pools))
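        # After merge_inference_pools_inplace(), several backends share one pool object;
        # deduplicating here keeps the runtime from starting and polling the same pool twice.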