import multiprocessing as mp
import multiprocessing.pool
import threading
from collections import defaultdict
from itertools import chain
from queue import SimpleQueue
from selectors import EVENT_READ, DefaultSelector
from statistics import mean
from time import time
from typing import Dict, NamedTuple, Optional

import torch
from prefetch_generator import BackgroundGenerator

from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger

logger = get_logger(__name__)


class Runtime(threading.Thread):
    """
    A group of processes that processes incoming requests for multiple module backends on a shared device.
    Runtime is usually created and managed by Server, humans need not apply.
    For debugging, you can start the runtime manually with .start() or .run()

    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
    >>> runtime = Runtime(module_backends)
    >>> runtime.start()  # start runtime in a background thread. To start in the current thread, use runtime.run()
    >>> runtime.ready.wait()  # wait for the runtime to load all blocks on device and create request pools
    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
    >>> print("Returned:", future.result())
    >>> runtime.shutdown()

    :param module_backends: a dict [block uid -> ModuleBackend]
    :param prefetch_batches: form up to this many batches in advance
    :param sender_threads: dispatch outputs from finished batches using this many asynchronous threads
    :param device: if specified, moves all blocks and data to this device via .to(device=device).
      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
    :param stats_report_interval: interval (in seconds) to collect and log statistics about runtime performance
    """

    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

    def __init__(
        self,
        module_backends: Dict[str, ModuleBackend],
        prefetch_batches: int = 1,
        sender_threads: int = 1,
        device: Optional[torch.device] = None,
        stats_report_interval: Optional[int] = None,
    ):
        super().__init__()
        self.module_backends = module_backends
        # gather every task pool exposed by each backend into one flat tuple
        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
        self.shutdown_trigger = mp.Event()
        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
        self.stats_report_interval = stats_report_interval
        if self.stats_report_interval is not None:
            self.stats_reporter = StatsReporter(self.stats_report_interval)

    def run(self):
        for pool in self.pools:
            if not pool.is_alive():
                pool.start()
        if self.device is not None:
            for backend in self.module_backends.values():
                backend.module.to(self.device)

        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                self.ready.set()
                if self.stats_report_interval is not None:
                    self.stats_reporter.start()
                logger.info("Started")

                batch_iterator = self.iterate_minibatches_from_pools()
                if self.prefetch_batches > 0:
                    # form up to `prefetch_batches` batches in a background thread
                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)

                for pool, batch_index, batch in batch_iterator:
                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

                    start = time()
                    try:
                        outputs = pool.process_func(*batch)
                        # dispatch results asynchronously so the next batch is not delayed
                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])

                        batch_processing_time = time() - start
                        batch_size = outputs[0].size(0)
                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

                        if self.stats_report_interval is not None:
                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
                    except KeyboardInterrupt:
                        raise
                    except BaseException as exception:
                        logger.exception(f"Caught {exception}, attempting to recover")
                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
            finally:
                if not self.shutdown_trigger.is_set():
                    self.shutdown()

    def shutdown(self):
        """Gracefully terminate a running runtime."""
        logger.info("Shutting down")
        self.ready.clear()

        if self.stats_report_interval is not None:
            self.stats_reporter.stop.set()
            self.stats_reporter.join()

        logger.debug("Terminating pools")
        for pool in self.pools:
            if pool.is_alive():
                pool.shutdown()
        logger.debug("Pools terminated")

        # trigger background thread to shutdown
        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
        self.shutdown_trigger.set()

    def iterate_minibatches_from_pools(self, timeout=None):
        """
        Chooses a pool according to priority, then copies the exposed batch and frees the buffer.
        """
        with DefaultSelector() as selector:
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)
            # the shutdown pipe shares the selector, so a trigger wakes this loop immediately
            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)

            while True:
                # wait until at least one batch_receiver becomes available
                logger.debug("Waiting for inputs from task pools")
                ready_fds = selector.select()
                ready_objects = {key.data for (key, events) in ready_fds}
                if self.SHUTDOWN_TRIGGER in ready_objects:
                    break  # someone asked us to shutdown, break from the loop

                logger.debug("Choosing the pool with first priority")
                pool = min(ready_objects, key=lambda pool: pool.priority)

                logger.debug(f"Loading batch from {pool.name}")
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                logger.debug(f"Loaded batch from {pool.name}")
                yield pool, batch_index, batch_tensors
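

# Illustrative sketch (not part of the original module): iterate_minibatches_from_pools
# above multiplexes every pool's batch_receiver plus a shutdown pipe on one selector,
# then serves the ready pool with the smallest priority value. The hypothetical demo
# below shows the same pattern with plain pipes; run _demo_priority_multiplexing() to
# see messages drained in priority order ("high", "mid", "low").
def _demo_priority_multiplexing():
    """Standalone sketch: several data pipes plus one shutdown pipe on a single selector."""
    shutdown_recv, shutdown_send = mp.Pipe(duplex=False)
    receivers = []
    for priority, message in [(2, "low"), (0, "high"), (1, "mid")]:
        recv, send = mp.Pipe(duplex=False)
        send.send(message)
        receivers.append((recv, priority))
    shutdown_send.send("SHUTDOWN")

    with DefaultSelector() as selector:
        for recv, priority in receivers:
            selector.register(recv, EVENT_READ, priority)
        selector.register(shutdown_recv, EVENT_READ, "SHUTDOWN")

        while True:
            ready = [key for key, _ in selector.select()]
            data_keys = [key for key in ready if key.data != "SHUTDOWN"]
            # unlike Runtime (which exits as soon as the trigger arrives), this demo
            # drains pending messages first, lowest priority value first
            if not data_keys:
                break  # only the shutdown pipe is left readable
            key = min(data_keys, key=lambda key: key.data)
            print(f"{key.fileobj.recv()} (priority {key.data})")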


BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))


class StatsReporter(threading.Thread):
    def __init__(self, report_interval: int):
        super().__init__()
        self.report_interval = report_interval
        self.stop = threading.Event()
        self.stats_queue = SimpleQueue()

    def run(self):
        # wait() returns False on timeout, so stats are reported every
        # report_interval seconds until stop is set
        while not self.stop.wait(self.report_interval):
            pool_batch_stats = defaultdict(list)
            while not self.stats_queue.empty():
                pool_uid, batch_stats = self.stats_queue.get()
                pool_batch_stats[pool_uid].append(batch_stats)

            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")

            for pool_uid, pool_stats in pool_batch_stats.items():
                total_batches = len(pool_stats)
                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)

                # report rates in whichever unit yields a number above 1,
                # inverting the rate when reporting seconds per item
                batches_to_time = total_batches / total_time
                batch_performance = (
                    f"{batches_to_time:.2f} batches/s"
                    if batches_to_time > 1
                    else f"{total_time / total_batches:.2f} s/batch"
                )

                examples_to_time = total_examples / total_time
                example_performance = (
                    f"{examples_to_time:.2f} examples/s"
                    if examples_to_time > 1
                    # max(..., 1) guards against an interval of all-empty batches
                    else f"{total_time / max(total_examples, 1):.2f} s/example"
                )

                logger.info(
                    f"{pool_uid}: "
                    f"{total_batches} batches ({batch_performance}), "
                    f"{total_examples} examples ({example_performance}), "
                    f"avg batch size {avg_batch_size:.2f}"
                )

    def report_stats(self, pool_uid, batch_size, processing_time):
        batch_stats = BatchStats(batch_size, processing_time)
        self.stats_queue.put_nowait((pool_uid, batch_stats))
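

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The real pools come from
# ModuleBackend.get_pools(); the hypothetical _DummyPool/_DummyBackend classes
# below implement only the interface Runtime actually touches (is_alive/start/
# shutdown, batch_receiver, name, priority, load_batch_to_runtime, process_func,
# send_outputs_from_runtime, send_exception_from_runtime, get_pools, .module),
# so the event loop can be exercised without a real server.
# ---------------------------------------------------------------------------


class _DummyPool:
    """Hypothetical stand-in for a task pool, for demonstration only."""

    def __init__(self, name: str, priority: int = 0):
        self.name, self.priority = name, priority
        self.batch_receiver, self._batch_sender = mp.Pipe(duplex=False)
        self._alive = False
        self.batch_done = threading.Event()

    def is_alive(self):
        return self._alive

    def start(self):
        self._alive = True  # a real pool would start its batch-forming process here

    def shutdown(self):
        self._alive = False

    def submit_batch(self, batch_index: int, inputs: torch.Tensor):
        # push one pre-formed batch through the pipe that Runtime selects on
        self._batch_sender.send((batch_index, [inputs]))

    def load_batch_to_runtime(self, timeout=None, device=None):
        batch_index, tensors = self.batch_receiver.recv()
        return batch_index, [t.to(device) if device is not None else t for t in tensors]

    def process_func(self, inputs: torch.Tensor):
        return (inputs * 2,)  # toy computation in place of a module forward pass

    def send_outputs_from_runtime(self, batch_index: int, outputs):
        print(f"{self.name}: batch {batch_index} -> output shape {tuple(outputs[0].shape)}")
        self.batch_done.set()

    def send_exception_from_runtime(self, batch_index: int, exception):
        print(f"{self.name}: batch {batch_index} failed with {exception!r}")
        self.batch_done.set()


class _DummyBackend:
    """Hypothetical ModuleBackend stand-in exposing just .module and .get_pools()."""

    def __init__(self, pool: _DummyPool):
        self.module = torch.nn.Identity()
        self._pool = pool

    def get_pools(self):
        return (self._pool,)


if __name__ == "__main__":
    pool = _DummyPool("toy_pool")
    runtime = Runtime({"toy_block": _DummyBackend(pool)})  # duck-typed in place of ModuleBackend
    runtime.start()
    runtime.ready.wait()  # wait until pools are started and the event loop is live
    pool.submit_batch(0, torch.ones(4, 3))
    pool.batch_done.wait(timeout=5)  # outputs are sent asynchronously; wait for ours
    runtime.shutdown()
    runtime.join()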