123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- import multiprocessing as mp
- import threading
- from itertools import chain
- from selectors import DefaultSelector, EVENT_READ
- from typing import Dict
- import torch
- import tqdm
- from prefetch_generator import BackgroundGenerator
- from .expert_backend import ExpertBackend
- from .task_pool import TaskPool, TaskPoolBase
class TesseractRuntime(threading.Thread):
    """
    A thread that processes batches for multiple experts on a shared device.

    When started, the runtime:
    1. starts every task pool that is not already alive;
    2. optionally moves all expert backends to ``device``;
    3. repeatedly picks the highest-priority pool with a batch ready and runs ``pool.process_func`` on it;
    4. hands the outputs back to the pool via a small thread pool of senders.
    The loop ends when :meth:`shutdown` is called.
    """

    # Message written into the shutdown pipe. It is also the selector key for that pipe,
    # so the batch iterator can tell it apart from task pools.
    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

    def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches: int = 64,
                 sender_threads: int = 1, device: torch.device = None):
        """
        :param expert_backends: a dict [expert uid -> ExpertBackend]; every task pool returned by each
            backend's ``get_pools()`` is served by this runtime
        :param prefetch_batches: generate up to this many batches in advance (forwarded to BackgroundGenerator)
        :param sender_threads: number of threads used to send outputs back to the task pools
        :param device: if not None, all expert backends are moved to this device when the runtime starts
        """
        super().__init__()
        self.expert_backends = expert_backends
        self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        # one-way pipe: shutdown() writes into shutdown_send, the batch iterator selects on shutdown_recv
        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches

    def run(self):
        """
        Main loop: start pools, move experts to ``device``, then process batches until shutdown.

        :meth:`shutdown` is always called when this method exits, whether normally or with an exception.
        This includes failures during start-up, so started task pools are terminated and ``ready`` is cleared.
        """
        import logging
        # ``import multiprocessing`` alone does not load the ``multiprocessing.pool`` submodule
        from multiprocessing.pool import ThreadPool

        logger = logging.getLogger(__name__)

        def report_send_error(exc: BaseException):
            # apply_async drops exceptions unless an error_callback is given; make failed sends visible
            logger.error("Failed to send outputs back to a task pool", exc_info=exc)

        progress = tqdm.tqdm(bar_format='{desc}, {rate_fmt}')
        with ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                for pool in self.pools:
                    if not pool.is_alive():
                        pool.start()
                if self.device is not None:
                    for expert_backend in self.expert_backends.values():
                        expert_backend.to(self.device)

                self.ready.set()
                for pool, batch_index, batch in BackgroundGenerator(
                        self.iterate_minibatches_from_pools(), self.prefetch_batches):
                    outputs = pool.process_func(*batch)
                    output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs],
                                                   error_callback=report_send_error)
                    # set the description first so the refresh triggered by update() shows the current batch
                    progress.set_description_str(f'{pool.uid=} {len(outputs[0])=}', refresh=False)
                    progress.update(len(outputs[0]))
            finally:
                progress.close()
                # runs before the sender ThreadPool is terminated by the ``with`` block exit
                self.shutdown()

    def shutdown(self):
        """
        Trigger the runtime to terminate.

        - Clears ``ready``.
        - Wakes the batch iterator through the shutdown pipe.
        - Terminates every task pool that is still alive.

        It is called automatically when :meth:`run` exits, and may also be called explicitly.
        NOTE(review): the pipe and event are multiprocessing primitives, but ``pool.terminate()`` only affects
        pools owned by this object — confirm whether cross-process calls are intended.
        """
        self.ready.clear()
        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)  # trigger background thread to shutdown
        for pool in self.pools:
            if pool.is_alive():
                pool.terminate()

    def iterate_minibatches_from_pools(self, timeout=None):
        """
        Yield ``(pool, batch_index, batch_tensors)`` for each batch until shutdown is requested.

        Blocks until the shutdown pipe or at least one pool's ``batch_receiver`` is readable. Among the ready
        pools, the one with the highest ``priority`` is chosen. Its batch is then fetched with
        ``pool.load_batch_to_runtime(timeout, self.device)``.
        Presumably that call copies the exposed batch and frees the pool's buffer — confirm in TaskPool.

        :param timeout: forwarded unchanged to ``pool.load_batch_to_runtime``
        """
        with DefaultSelector() as selector:
            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)

            while True:
                # wait until the shutdown pipe or at least one batch_receiver becomes available
                ready_objects = {key.data for key, _events in selector.select()}
                if self.SHUTDOWN_TRIGGER in ready_objects:
                    break  # someone asked us to shutdown; this takes precedence over pending batches
                pool = max(ready_objects, key=lambda ready_pool: ready_pool.priority)
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                yield pool, batch_index, batch_tensors
|