server.py

from __future__ import annotations

import multiprocessing as mp
import threading
from typing import Dict, Optional, Sequence, Union

import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.dht_handler import DHTHandlerThread
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src import declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
from src.server.backend import TransformerBackend
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class Server(threading.Thread):
    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""

    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        *,
        device: torch.device,
        num_connection_handlers: int = 8,
        update_period: float = 30,
        expiration: Optional[float] = None,
        start: bool,
        **kwargs,
    ):
        threading.Thread.__init__(self)
        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
        self.conn_handlers = [
            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
        ]
        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
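        # periodically (re)announce the hosted modules to the DHT so that peers can discover this server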
        self.dht_handler_thread = ModuleAnnouncerThread(
            self.module_backends, dht, update_period, expiration, daemon=True
        )
        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state

        if start:
            self.run_in_background(await_ready=True)

    def run(self):
        """
        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Serving {len(self.module_backends)} blocks:")
        for expert_name, backend in self.module_backends.items():
            # count all parameters: the hosted blocks are frozen, so counting only trainable ones would report zero
            num_parameters = sum(p.numel() for p in backend.module.parameters())
            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")

        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        if self.module_backends:
            self.dht_handler_thread.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()
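
        # each connection handler runs in its own process: start it if needed and block until it reports ready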
        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()
            process.ready.result()

        try:
            self.runtime.run()
        finally:
            self.shutdown()

    # noinspection PyMethodOverriding
    @classmethod
    def create(
        cls,
        prefix: str,
        converted_model_name_or_path: str,
        num_blocks: Optional[int] = None,
        block_indices: Optional[str] = None,
        num_handlers: Optional[int] = None,
        min_batch_size: int = 1,
        max_batch_size: int = 4096,
        torch_dtype: str = "auto",
        cache_size_bytes: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        initial_peers: Sequence[str] = (),
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 30,
        expiration: Optional[float] = None,
        use_auth_token: Optional[str] = None,
        *,
        start: bool,
        **kwargs,
    ) -> Server:
        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)
        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"

        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        memory_cache = MemoryCache(device, cache_size_bytes)

        if isinstance(torch_dtype, str):
            torch_dtype = DTYPE_MAP[torch_dtype]
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
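
        # decide which transformer blocks to host: either an explicit "start:end" range or the first num_blocks blocks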
        if block_indices is not None:
            try:
                # parse "start:end" into local names so they do not clobber the `start` flag passed to cls(...) below
                first_block_index, last_block_index = block_indices.split(":")
                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
            except Exception as e:
                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:33)")
                raise
            block_indices = range(first_block_index, last_block_index)
        else:
            assert num_blocks is not None
            block_indices = range(num_blocks)  # TODO replace with proper load balancing
        block_config = DistributedBloomConfig.from_pretrained(
            converted_model_name_or_path, use_auth_token=use_auth_token
        )
        # initialize modules
        blocks = {}
        for block_index in block_indices:
            module_uid = f"{prefix}.{block_index}"
            block = load_pretrained_block(
                converted_model_name_or_path,
                block_index,
                block_config,
                torch_dtype=torch_dtype,
                use_auth_token=use_auth_token,
            )
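            # the hosted weights are served as-is (no training), so parameter gradients are not needed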
            for param in block.parameters():
                param.requires_grad = False
            blocks[module_uid] = TransformerBackend(
                module_uid,
                block,
                memory_cache=memory_cache,
                args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                kwargs_schema={},
                outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )
        num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
        return cls(
            dht,
            blocks,
            num_connection_handlers=num_handlers,
            device=device,
            stats_report_interval=stats_report_interval,
            update_period=update_period,
            expiration=expiration,
            start=start,
        )

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts Server in a background thread. If await_ready, this method will wait until the background server
        is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the server is ready to process requests.

        Example
        =======
        >>> server.start()
        >>> server.ready.wait(timeout=10)
        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the server, process-safe.
        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        self.ready.clear()

        for process in self.conn_handlers:
            process.terminate()
            process.join()
        logger.debug("Connection handlers terminated")

        if self.module_backends:
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        self.dht.shutdown()
        self.dht.join()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()
        logger.info("Server shut down successfully")


class ModuleAnnouncerThread(threading.Thread):
    """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""

    def __init__(
        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[float] = None, **kwargs
    ):
        super().__init__(**kwargs)
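        # default expiration: twice the update period, but no less than MAX_DHT_TIME_DISCREPANCY_SECONDS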
        if expiration is None:
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)

        self.module_backends = module_backends
        self.dht = dht
        self.update_period = update_period
        self.expiration = expiration
        self.stop = threading.Event()

    def run(self) -> None:
        declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
        while not self.stop.wait(self.update_period):
            declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
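

# Illustrative sketch: launching a server directly instead of via run_server.py (the documented entry point).
# The prefix, model path, and block count below are placeholders, not values taken from this repository.
if __name__ == "__main__":
    server = Server.create(
        prefix="bloom",  # placeholder UID prefix; module UIDs become "{prefix}.{block_index}"
        converted_model_name_or_path="path/to/converted-bloom-model",  # placeholder path to a converted checkpoint
        num_blocks=8,  # placeholder: host the first num_blocks blocks
        start=True,
    )
    try:
        server.join()  # Server is a threading.Thread; block until the runtime stops
    except KeyboardInterrupt:
        server.shutdown()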