__init__.py

import multiprocessing as mp
import multiprocessing.pool  # makes mp.pool.ThreadPool available in run()
import threading
from itertools import chain
from selectors import DefaultSelector, EVENT_READ
from typing import Dict

import torch
import tqdm
from prefetch_generator import BackgroundGenerator

from .expert_backend import ExpertBackend
from .task_pool import TaskPool, TaskPoolBase

class TesseractRuntime(threading.Thread):
    """
    A group of processes that processes incoming requests for multiple experts on a shared device.
    TesseractRuntime is usually created and managed by TesseractServer, humans need not apply.
    For debugging, you can start the runtime manually with .start() or .run()

    >>> expert_backends = {'expert_name': ExpertBackend(**kwargs)}
    >>> runtime = TesseractRuntime(expert_backends)
    >>> runtime.start()  # start runtime in a background thread. To run in the current thread, use runtime.run()
    >>> runtime.ready.wait()  # wait until the runtime has loaded all experts onto the device and created request pools
    >>> future = runtime.expert_backends['expert_name'].forward_pool.submit_task(*expert_inputs)
    >>> print("Returned:", future.result())
    >>> runtime.shutdown()

    :param expert_backends: a dict [expert uid -> ExpertBackend]
    :param prefetch_batches: form up to this many batches in advance
    :param sender_threads: dispatch outputs from finished batches using this many asynchronous threads
    :param device: if specified, moves all experts and data to this device via .to(device=device).
      If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
    """
    def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches: int = 64, sender_threads: int = 1,
                 device: torch.device = None):
        super().__init__()
        self.expert_backends = expert_backends
        # collect every task pool (e.g. forward / backward) exposed by the expert backends
        self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
        self.ready = mp.Event()  # event is set iff the runtime is currently running and ready to accept batches
    def run(self):
        progress = tqdm.tqdm(bar_format='{desc}, {rate_fmt}')
        for pool in self.pools:
            if not pool.is_alive():
                pool.start()
        if self.device is not None:
            for expert_backend in self.expert_backends.values():
                expert_backend.to(self.device)

        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                self.ready.set()
                for pool, batch_index, batch in BackgroundGenerator(
                        self.iterate_minibatches_from_pools(), self.prefetch_batches):
                    outputs = pool.process_func(*batch)
                    # return results to the pool asynchronously so the main loop can move on to the next batch
                    output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
                    progress.update(len(outputs[0]))
                    progress.desc = f'pool.uid={pool.uid} batch_size={len(outputs[0])}'
            finally:
                self.shutdown()
    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

    def shutdown(self):
        """ Gracefully terminate a running runtime. """
        self.ready.clear()
        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)  # trigger background thread to shut down

        for pool in self.pools:
            if pool.is_alive():
                pool.terminate()
    def iterate_minibatches_from_pools(self, timeout=None):
        """
        Chooses pool according to priority, then copies exposed batch and frees the buffer
        """
        with DefaultSelector() as selector:
            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)

            while True:
                # wait until at least one batch_receiver becomes available
                ready_fds = selector.select()
                ready_objects = {key.data for (key, events) in ready_fds}
                if self.SHUTDOWN_TRIGGER in ready_objects:
                    break  # someone asked us to shutdown, break from the loop

                pool = max(ready_objects, key=lambda pool: pool.priority)
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                yield pool, batch_index, batch_tensors
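
For reference, the sketch below is a minimal, illustrative example (not part of the library) of the informal pool interface that TesseractRuntime drives: it only uses the attributes and methods that run(), shutdown() and iterate_minibatches_from_pools() actually call (uid, priority, batch_receiver, is_alive/start/terminate, load_batch_to_runtime, process_func, send_outputs_from_runtime). The _DemoPool and _DemoBackend classes are hypothetical stand-ins for .task_pool.TaskPool and .expert_backend.ExpertBackend, and the import path is an assumption; run it as a separate script and adjust the import to your package layout.

# Illustrative sketch only: _DemoPool and _DemoBackend are hypothetical stand-ins, not library classes.
import multiprocessing as mp
import queue

import torch

from tesseract_runtime import TesseractRuntime  # hypothetical package path, adjust to your layout


class _DemoPool:
    """Stand-in task pool that feeds pre-made batches to the runtime."""

    def __init__(self, uid: str):
        self.uid, self.priority = uid, 1.0
        # the runtime's selector waits on the read end of this pipe
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
        self.outputs = queue.Queue()

    # lifecycle hooks expected by TesseractRuntime.run / .shutdown
    def is_alive(self):
        return True

    def start(self):
        pass

    def terminate(self):
        pass

    def load_batch_to_runtime(self, timeout=None, device=None):
        batch_index, batch_tensors = self.batch_receiver.recv()
        if device is not None:
            batch_tensors = [tensor.to(device) for tensor in batch_tensors]
        return batch_index, batch_tensors

    def process_func(self, *batch):
        return [batch[0] * 2]  # toy "expert": multiply inputs by two

    def send_outputs_from_runtime(self, batch_index, outputs):
        self.outputs.put((batch_index, outputs))


class _DemoBackend:
    """Stand-in for an ExpertBackend that exposes a single pool."""

    def __init__(self):
        self.pool = _DemoPool(uid='demo_expert.forward')

    def get_pools(self):
        return [self.pool]


if __name__ == '__main__':
    backend = _DemoBackend()
    runtime = TesseractRuntime({'demo_expert': backend})  # device=None: experts keep their own placement
    runtime.start()
    runtime.ready.wait()

    # hand the runtime one batch, then read back the processed outputs
    backend.pool.batch_sender.send((0, [torch.ones(4)]))
    batch_index, outputs = backend.pool.outputs.get(timeout=5)
    print('processed batch', batch_index, '->', outputs[0])

    runtime.shutdown()
    runtime.join()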