threading.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from concurrent.futures import Future, as_completed
  2. import time
  3. from threading import Thread
  4. from typing import Optional, List
  def run_in_background(func: callable, *args, **kwargs) -> Future:
      """
      Run func(*args, **kwargs) in a new thread and return a Future for its outcome.

      The Future receives func's return value, or the exception it raised (any BaseException).
      The Future is never marked as running, so a caller may still cancel() it while func executes.
      Cancelling does not stop the thread; the eventual outcome is then simply discarded.
      """
      future = Future()

      def _run():
          # keep func's own failure separate from a failure to store the outcome in the future
          try:
              result = func(*args, **kwargs)
          except BaseException as e:
              settle, outcome = future.set_exception, e
          else:
              settle, outcome = future.set_result, result
          try:
              settle(outcome)
          except Exception:
              # storing an outcome on a future that was cancelled meanwhile raises InvalidStateError
              # (Python 3.8+); the caller gave up on it, so drop the outcome instead of crashing the thread
              if not future.cancelled():
                  raise

      Thread(target=_run).start()
      return future
  def run_forever(func: callable, *args, **kwargs):
      """
      Call func(*args, **kwargs) again and again in a background thread, with no stopping condition.

      :returns: the Future created by run_in_background; it never receives a result and only
        completes (holding the exception) if a call to func raises
      """
      def _call_in_loop():
          # the loop has no exit of its own: leaving it requires func to raise
          while True:
              func(*args, **kwargs)

      background_future = run_in_background(_call_in_loop)
      return background_future
  def run_and_await_k(jobs: List[callable], k: int,
                      timeout_after_k: Optional[float] = 0, timeout_total: Optional[float] = None):
      """
      Run all :jobs: in background threads and wait until at least k of them finish successfully.

      Each job is started with run_in_background. "Cancelling" an unfinished job only marks its Future
      as cancelled; the thread executing the job is not interrupted and keeps running.

      :param jobs: zero-argument functions to call asynchronously
      :param k: how many jobs must finish without raising for the call to be successful
      :param timeout_after_k: after reaching k successful jobs, wait at most this many seconds for the
        remaining jobs before cancelling them (None: no limit other than timeout_total)
      :param timeout_total: if specified, stop waiting and cancel unfinished jobs once this many
        seconds have passed since the call started
      :returns: a list with one entry per job, in the order of :jobs:: the job's return value,
        the exception it raised, or a CancelledError instance if it was cancelled while unfinished
      :raises ValueError: if so many jobs failed that k successes can no longer be reached
      :raises TimeoutError: if timeout_total expired before k jobs succeeded
      """
      # concurrent.futures.TimeoutError is an alias of the builtin TimeoutError only since Python 3.11;
      # catch both so the timeout raised by as_completed is handled on older versions as well
      from concurrent.futures import CancelledError, TimeoutError as FuturesTimeoutError

      jobs = list(jobs)
      assert k <= len(jobs), f"Can't await {k} out of {len(jobs)} jobs."
      start_time = time.time()
      future_to_ix = {run_in_background(job): i for i, job in enumerate(jobs)}
      outputs = [None] * len(jobs)
      success_count = 0

      def _collect(future: Future) -> bool:
          """Move a finished future's result or exception into outputs; return True if it succeeded."""
          error = future.exception()
          outputs[future_to_ix.pop(future)] = future.result() if error is None else error
          return error is None

      try:
          # phase 1: wait (up to timeout_total) until k jobs have succeeded; nothing to wait for if k <= 0
          if k > 0:
              for future in as_completed(list(future_to_ix), timeout=timeout_total):
                  if _collect(future):
                      success_count += 1
                  if success_count >= k:
                      break  # we have enough successful jobs
                  # outputs is pre-sized, so reachability is successes so far + jobs still pending
                  if success_count + len(future_to_ix) < k:
                      failed = len(jobs) - len(future_to_ix) - success_count
                      raise ValueError(f"Couldn't get enough results: too many jobs failed ({failed} / {len(jobs)})")

          # phase 2: give stragglers at most timeout_after_k more seconds, capped by what is left of timeout_total
          time_left = timeout_after_k
          if timeout_total is not None:
              remaining_total = timeout_total - (time.time() - start_time)
              time_left = remaining_total if timeout_after_k is None else min(timeout_after_k, remaining_total)
          for future in as_completed(list(future_to_ix), timeout=time_left):
              if _collect(future):
                  success_count += 1
      except (FuturesTimeoutError, TimeoutError) as e:
          # running out of time is only an error if we still lack k successes
          if success_count < k:
              raise TimeoutError(
                  f"Couldn't get enough results: time limit exceeded (got {success_count} of {k})") from e
      finally:
          # cancel whatever is still unfinished; never block or raise here (that would mask errors above)
          for future, index in future_to_ix.items():
              if future.cancel():
                  outputs[index] = CancelledError()
              elif future.done():  # finished after the last wait, before we got to cancel it
                  error = future.exception()
                  outputs[index] = future.result() if error is None else error
      return outputs