runtime.py

import multiprocessing as mp
import multiprocessing.pool
import threading
from collections import defaultdict
from itertools import chain
from queue import SimpleQueue
from selectors import EVENT_READ, DefaultSelector
from statistics import mean
from time import time
from typing import Dict, NamedTuple, Optional

import torch
from prefetch_generator import BackgroundGenerator

from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger

logger = get_logger(__name__)


class Runtime(threading.Thread):
    """
    A group of processes that processes incoming requests for multiple module backends on a shared device.
    Runtime is usually created and managed by Server, humans need not apply.
    For debugging, you can start runtime manually with .start() or .run()

    >>> module_backends = {'expert_name': ModuleBackend(**kwargs)}
    >>> runtime = Runtime(module_backends)
    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
    >>> runtime.ready.wait()  # wait for runtime to load all experts on device and create request pools
    >>> future = runtime.module_backends['expert_name'].forward_pool.submit_task(*module_inputs)
    >>> print("Returned:", future.result())
    >>> runtime.shutdown()

    :param module_backends: a dict [expert uid -> ModuleBackend]
    :param prefetch_batches: form up to this many batches in advance
    :param sender_threads: dispatch outputs from finished batches using this many asynchronous threads
    :param device: if specified, moves all experts and data to this device via .to(device=device).
      If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
    :param stats_report_interval: interval (in seconds) to collect and log statistics about runtime performance
    """

    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

    def __init__(
        self,
        module_backends: Dict[str, ModuleBackend],
        prefetch_batches: int = 64,
        sender_threads: int = 1,
        device: Optional[torch.device] = None,
        stats_report_interval: Optional[int] = None,
    ):
        super().__init__()
        self.module_backends = module_backends
        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
        self.shutdown_trigger = mp.Event()
        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
        self.stats_report_interval = stats_report_interval
        if self.stats_report_interval is not None:
            self.stats_reporter = StatsReporter(self.stats_report_interval)

    def run(self):
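        """Start the task pools, move modules to the target device (if set) and process batches until shutdown"""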
        for pool in self.pools:
            if not pool.is_alive():
                pool.start()
        if self.device is not None:
            for backend in self.module_backends.values():
                backend.module.to(self.device)

        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                self.ready.set()
                if self.stats_report_interval is not None:
                    self.stats_reporter.start()
                logger.info("Started")

                batch_iterator = self.iterate_minibatches_from_pools()
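                # if enabled, prefetch up to `prefetch_batches` batches in a background thread
                # so that batch loading overlaps with module computation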
                if self.prefetch_batches > 0:
                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)

                for pool, batch_index, batch in batch_iterator:
                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

                    start = time()
                    try:
                        outputs = pool.process_func(*batch)
                        batch_processing_time = time() - start

                        batch_size = outputs[0].size(0)
                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

                        if self.stats_report_interval is not None:
                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
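
                        # hand the outputs to a sender thread so the main loop can move on to the next batch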
                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
                    except KeyboardInterrupt:
                        raise
                    except BaseException as exception:
                        logger.exception(f"Caught {exception}, attempting to recover")
                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
            finally:
                if not self.shutdown_trigger.is_set():
                    self.shutdown()

    def shutdown(self):
        """Gracefully terminate a running runtime."""
        logger.info("Shutting down")
        self.ready.clear()

        if self.stats_report_interval is not None:
            self.stats_reporter.stop.set()
            self.stats_reporter.join()

        logger.debug("Terminating pools")
        for pool in self.pools:
            if pool.is_alive():
                pool.terminate()
                pool.join()
        logger.debug("Pools terminated")

        # trigger the background thread to shut down
        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
        self.shutdown_trigger.set()

    def iterate_minibatches_from_pools(self, timeout=None):
        """Iteratively select the non-empty pool with the highest priority and load a batch from that pool"""
        with DefaultSelector() as selector:
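            # watch every pool's batch_receiver pipe plus the shutdown pipe with a single selector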
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)
            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)

            while True:
                # wait until at least one batch_receiver becomes available
                logger.debug("Waiting for inputs from task pools")
                ready_fds = selector.select()
                ready_objects = {key.data for (key, events) in ready_fds}
                if self.SHUTDOWN_TRIGGER in ready_objects:
                    break  # someone asked us to shut down, break from the loop

                logger.debug("Choosing the pool with the highest priority")
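                # the pool reporting the smallest priority value is served first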
                pool = min(ready_objects, key=lambda pool: pool.priority)

                logger.debug(f"Loading batch from {pool.name}")
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                logger.debug(f"Loaded batch from {pool.name}")
                yield pool, batch_index, batch_tensors


BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))


class StatsReporter(threading.Thread):
    def __init__(self, report_interval: int):
        super().__init__()
        self.report_interval = report_interval
        self.stop = threading.Event()
        self.stats_queue = SimpleQueue()

    def run(self):
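        # once per report interval, drain all queued stats and log a per-pool summary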
        while not self.stop.wait(self.report_interval):
            pool_batch_stats = defaultdict(list)
            while not self.stats_queue.empty():
                pool_uid, batch_stats = self.stats_queue.get()
                pool_batch_stats[pool_uid].append(batch_stats)

            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
            for pool_uid, pool_stats in pool_batch_stats.items():
                total_batches = len(pool_stats)
                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
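
                # report each rate in whichever unit keeps the number above 1, inverting the value to match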
                batches_to_time = total_batches / total_time
                batch_performance = (
                    f"{batches_to_time:.2f} batches/s"
                    if batches_to_time > 1
                    else f"{total_time / total_batches:.2f} s/batch"
                )

                examples_to_time = total_examples / total_time
                example_performance = (
                    f"{examples_to_time:.2f} examples/s"
                    if examples_to_time > 1
                    else f"{total_time / total_examples:.2f} s/example"
                )

                logger.info(
                    f"{pool_uid}: "
                    f"{total_batches} batches ({batch_performance}), "
                    f"{total_examples} examples ({example_performance}), "
                    f"avg batch size {avg_batch_size:.2f}"
                )

    def report_stats(self, pool_uid, batch_size, processing_time):
        batch_stats = BatchStats(batch_size, processing_time)
        self.stats_queue.put_nowait((pool_uid, batch_stats))