# threading.py

import os
import time
from concurrent.futures import (
    CancelledError,
    Future,
    ThreadPoolExecutor,
    TimeoutError,
    as_completed,
)
from typing import Callable, List, Optional
  5. GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("TESSERACT_THREADS", float('inf')))
  6. def run_in_background(func: callable, *args, **kwargs) -> Future:
  7. """ run func(*args, **kwargs) in background and return Future for its outputs """
  8. return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
  9. def run_forever(func: callable, *args, **kwargs):
  10. """ A function that runs a :func: in background forever. Returns a future that catches exceptions """
  11. def repeat():
  12. while True:
  13. func(*args, **kwargs)
  14. return run_in_background(repeat)
  15. def run_and_await_k(jobs: List[callable], k: int,
  16. timeout_after_k: Optional[float] = 0, timeout_total: Optional[float] = None):
  17. """
  18. Runs all :jobs: asynchronously, awaits for at least k of them to finish
  19. :param jobs: functions to call asynchronously
  20. :param k: how many functions should finish for call to be successful
  21. :param timeout_after_k: after reaching k finished jobs, wait for this long before cancelling
  22. :param timeout_total: if specified, terminate cancel jobs after this many seconds
  23. :returns: a list of either results or exceptions for each job
  24. """
  25. jobs = list(jobs)
  26. assert k <= len(jobs), f"Can't await {k} out of {len(jobs)} jobs."
  27. start_time = time.time()
  28. future_to_ix = {run_in_background(job): i for i, job in enumerate(jobs)}
  29. outputs = [None] * len(jobs)
  30. success_count = 0
  31. try:
  32. # await first k futures for as long as it takes
  33. for future in as_completed(list(future_to_ix.keys()), timeout=timeout_total):
  34. success_count += int(not future.exception())
  35. outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
  36. if success_count >= k:
  37. break # we have enough futures to succeed
  38. if len(outputs) + len(future_to_ix) < k:
  39. failed = len(jobs) - len(outputs) - len(future_to_ix)
  40. raise ValueError(f"Couldn't get enough results: too many jobs failed ({failed} / {len(outputs)})")
  41. # await stragglers for at most self.timeout_after_k_min or whatever time is left
  42. if timeout_after_k is not None and timeout_total is not None:
  43. time_left = min(timeout_after_k, timeout_total - time.time() + start_time)
  44. else:
  45. time_left = timeout_after_k if timeout_after_k is not None else timeout_total
  46. for future in as_completed(list(future_to_ix.keys()), timeout=time_left):
  47. success_count += int(not future.exception())
  48. outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
  49. except TimeoutError:
  50. if len(outputs) < k:
  51. raise TimeoutError(f"Couldn't get enough results: time limit exceeded (got {len(outputs)} of {k})")
  52. finally:
  53. for future, index in future_to_ix.items():
  54. future.cancel()
  55. outputs[index] = future.result() if not future.exception() else future.exception()
  56. return outputs