__init__.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import multiprocessing as mp
  2. import threading
  3. from itertools import chain
  4. from selectors import DefaultSelector, EVENT_READ
  5. from typing import Dict
  6. import torch
  7. import tqdm
  8. from prefetch_generator import BackgroundGenerator
  9. from .expert_backend import ExpertBackend
  10. from .task_pool import TaskPool, TaskPoolBase
  class TesseractRuntime(threading.Thread):
      """
      A thread that takes batches from the task pools of several experts and processes them one at a time.

      When started, the thread starts every pool that is not alive yet, optionally moves all expert backends
      to a single device, and then repeatedly runs ``pool.process_func`` on the next available batch, handing
      the outputs to a small thread pool that sends them back via ``pool.send_outputs_from_runtime``.
      """

      # Message written to the shutdown pipe; it is also used as the selector key data for that pipe.
      SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

      def __init__(self, expert_backends: Dict[str, 'ExpertBackend'], prefetch_batches=64, sender_threads: int = 1,
                   device: torch.device = None):
          """
          Prepare the runtime. Nothing is processed until the thread itself is started.

          :param expert_backends: a dict [expert uid -> ExpertBackend]; each backend must provide get_pools()
          :param prefetch_batches: generate up to this many batches in advance
          :param sender_threads: size of the thread pool that sends outputs back to the task pools
          :param device: if not None, every expert backend is moved to this device in run();
              it is also passed to each pool when a batch is loaded
          """
          super().__init__()
          self.expert_backends = expert_backends
          # flat tuple with the pools of all experts
          self.pools = tuple(chain.from_iterable(expert.get_pools() for expert in expert_backends.values()))
          self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
          self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
          self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches

      def run(self):
          """
          Thread body: start pools, move experts to the device, then process batches until shutdown.

          shutdown() is always called on the way out (normal termination or any exception, including one
          raised while starting pools or moving experts), so pools started here do not outlive the runtime.
          """
          # `import multiprocessing` alone does not guarantee that the `multiprocessing.pool` submodule is
          # loaded, so import ThreadPool explicitly instead of relying on the `mp.pool` attribute.
          from multiprocessing.pool import ThreadPool

          # the context manager closes the progress bar when the runtime stops
          with tqdm.tqdm(bar_format='{desc}, {rate_fmt}') as progress:
              try:
                  for pool in self.pools:
                      if not pool.is_alive():
                          pool.start()
                  if self.device is not None:
                      for expert_backend in self.expert_backends.values():
                          expert_backend.to(self.device)

                  # NOTE: leaving this `with` block terminates the sender pool (ThreadPool.__exit__ calls
                  # terminate()), so outputs that are still queued for sending at that point are dropped.
                  with ThreadPool(self.sender_threads) as output_sender_pool:
                      self.ready.set()
                      for pool, batch_index, batch in BackgroundGenerator(
                              self.iterate_minibatches_from_pools(), self.prefetch_batches):
                          outputs = pool.process_func(*batch)
                          output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
                          progress.update(len(outputs[0]))
                          progress.desc = f'{pool.uid=} {len(outputs[0])=}'
              finally:
                  self.shutdown()

      def shutdown(self):
          """
          Trigger runtime to terminate: clear the ready flag, wake up the batch iterator, terminate alive pools.

          NOTE(review): the original docstring calls this "process-safe"; it only touches an mp.Event, an mp.Pipe
          and the pools, but whether it is actually invoked from other processes is not visible here.
          """
          self.ready.clear()
          self.shutdown_send.send(self.SHUTDOWN_TRIGGER)  # trigger background thread to shutdown
          for pool in self.pools:
              if pool.is_alive():
                  pool.terminate()

      def iterate_minibatches_from_pools(self, timeout=None):
          """
          Generator that waits for pools with a pending batch and yields batches, highest priority first.

          On every step it picks, among the pools whose batch_receiver is readable, the one with the largest
          ``priority`` and loads its batch via ``pool.load_batch_to_runtime``. The generator finishes once the
          shutdown pipe becomes readable; the trigger message is not consumed from the pipe.

          :param timeout: forwarded to pool.load_batch_to_runtime; the wait on the selector itself has no timeout
          :returns: yields tuples (pool, batch_index, batch_tensors)
          """
          with DefaultSelector() as selector:
              selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
              for pool in self.pools:
                  selector.register(pool.batch_receiver, EVENT_READ, pool)

              while True:
                  # wait until the shutdown pipe or at least one batch_receiver becomes available
                  ready_objects = {key.data for key, _events in selector.select()}
                  if not ready_objects:
                      # select() may return an empty list (e.g. when interrupted by a signal); max() below
                      # would raise ValueError on an empty set, so just wait again
                      continue
                  if self.SHUTDOWN_TRIGGER in ready_objects:
                      break  # someone asked us to shutdown, break from the loop

                  pool = max(ready_objects, key=lambda pool: pool.priority)
                  batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                  yield pool, batch_index, batch_tensors