runtime.py

import multiprocessing as mp
import multiprocessing.pool
import threading
from collections import defaultdict
from itertools import chain
from queue import SimpleQueue
from selectors import DefaultSelector, EVENT_READ
from statistics import mean
from time import time
from typing import Dict, NamedTuple, Optional

import torch
from prefetch_generator import BackgroundGenerator

from hivemind.server.expert_backend import ExpertBackend
from hivemind.utils import get_logger

logger = get_logger(__name__)


class Runtime(threading.Thread):
    """
    A group of processes that processes incoming requests for multiple experts on a shared device.
    Runtime is usually created and managed by Server, humans need not apply.
    For debugging, you can start the runtime manually with .start() or .run()

    >>> expert_backends = {'expert_name': ExpertBackend(**kwargs)}
    >>> runtime = Runtime(expert_backends)
    >>> runtime.start()  # start runtime in a background thread; to run in the current thread, use runtime.run()
    >>> runtime.ready.wait()  # wait until the runtime has loaded all experts on device and created request pools
    >>> future = runtime.expert_backends['expert_name'].forward_pool.submit_task(*expert_inputs)
    >>> print("Returned:", future.result())
    >>> runtime.shutdown()

    :param expert_backends: a dict [expert uid -> ExpertBackend]
    :param prefetch_batches: form up to this many batches in advance
    :param sender_threads: dispatch outputs from finished batches using this many asynchronous threads
    :param device: if specified, moves all experts and data to this device via .to(device=device).
      If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
    :param stats_report_interval: interval (in seconds) at which to collect and log statistics about runtime performance
    """

    def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
                 device: torch.device = None, stats_report_interval: Optional[int] = None):
        super().__init__()
        self.expert_backends = expert_backends
        self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
        self.stop = threading.Event()

        self.stats_report_interval = stats_report_interval
        if self.stats_report_interval is not None:
            self.stats_reporter = StatsReporter(self.stats_report_interval)
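
    # Main worker loop: starts every task pool, optionally moves experts to the shared device, then
    # keeps consuming prefetched batches, running each pool's process_func on them and handing the
    # results to a small thread pool so that sending outputs does not block the next batch.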
    def run(self):
        for pool in self.pools:
            if not pool.is_alive():
                pool.start()
        if self.device is not None:
            for expert_backend in self.expert_backends.values():
                expert_backend.expert.to(self.device)

        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                self.ready.set()
                if self.stats_report_interval is not None:
                    self.stats_reporter.start()
                logger.info("Started")

                for pool, batch_index, batch in BackgroundGenerator(
                        self.iterate_minibatches_from_pools(), self.prefetch_batches):
                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

                    start = time()
                    outputs = pool.process_func(*batch)
                    batch_processing_time = time() - start

                    batch_size = outputs[0].size(0)
                    logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

                    if self.stats_report_interval is not None:
                        self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

                    output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
            finally:
                self.shutdown()
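
    # Teardown order: stop the stats reporter first, then signal the batch iterator to exit via
    # self.stop, and finally terminate and join each task pool.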
    def shutdown(self):
        """ Gracefully terminate a running runtime. """
        logger.info("Shutting down")

        if self.stats_report_interval is not None:
            self.stats_reporter.stop.set()
            self.stats_reporter.join()

        self.stop.set()  # trigger the background thread to shut down

        logger.debug("Terminating pools")
        for pool in self.pools:
            if pool.is_alive():
                pool.terminate()
                pool.join()
        logger.debug("Pools terminated")
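
    # Each task pool exposes a selectable batch_receiver; DefaultSelector blocks on all of them at
    # once and wakes up as soon as any pool has a batch ready, so the runtime never busy-waits.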
    def iterate_minibatches_from_pools(self, timeout=None):
        """
        Chooses pool according to priority, then copies exposed batch and frees the buffer
        """
        with DefaultSelector() as selector:
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)

            while not self.stop.is_set():
                # wait until at least one batch_receiver becomes available
                logger.debug("Waiting for inputs from task pools")
                ready_fds = selector.select()
                ready_objects = {key.data for (key, events) in ready_fds}
                logger.debug("Choosing the pool with highest priority")
                pool = max(ready_objects, key=lambda pool: pool.priority)

                logger.debug(f"Loading batch from {pool.name}")
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                logger.debug(f"Loaded batch from {pool.name}")
                yield pool, batch_index, batch_tensors


BatchStats = NamedTuple('BatchStats', (('batch_size', int), ('processing_time', float)))
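
# StatsReporter receives (pool_uid, BatchStats) records from the Runtime thread via report_stats
# (backed by a thread-safe queue) and periodically logs aggregate throughput per pool.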
class StatsReporter(threading.Thread):
    def __init__(self, report_interval: int):
        super().__init__()
        self.report_interval = report_interval
        self.stop = threading.Event()
        self.stats_queue = SimpleQueue()
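
    # Every report_interval seconds: drain everything queued since the last report, group it by
    # pool uid, and log batch/example throughput and the average batch size for each pool.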
    def run(self):
        while not self.stop.wait(self.report_interval):
            pool_batch_stats = defaultdict(list)
            while not self.stats_queue.empty():
                pool_uid, batch_stats = self.stats_queue.get()
                pool_batch_stats[pool_uid].append(batch_stats)

            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
            logger.info(f'Processed {total_processed_batches} batches in last {self.report_interval} seconds:')
            for pool_uid, pool_stats in pool_batch_stats.items():
                total_batches = len(pool_stats)
                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)

                batches_to_time = total_batches / total_time
                batch_performance = (f'{batches_to_time:.2f} batches/s' if batches_to_time > 1
                                     else f'{1 / batches_to_time:.2f} s/batch')

                examples_to_time = total_examples / total_time
                example_performance = (f'{examples_to_time:.2f} examples/s' if examples_to_time > 1
                                       else f'{1 / examples_to_time:.2f} s/example')

                logger.info(f'{pool_uid}: '
                            f'{total_batches} batches ({batch_performance}), '
                            f'{total_examples} examples ({example_performance}), '
                            f'avg batch size {avg_batch_size:.2f}')

    def report_stats(self, pool_uid, batch_size, processing_time):
        batch_stats = BatchStats(batch_size, processing_time)
        self.stats_queue.put_nowait((pool_uid, batch_stats))