remote_expert_worker.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import os
  2. from concurrent.futures import Future
  3. from queue import Queue
  4. from threading import Thread
  5. from typing import Awaitable, Optional
  6. from hivemind.utils import switch_to_uvloop
  7. class RemoteExpertWorker:
  8. """Local thread for managing async tasks related to RemoteExpert"""
  9. _task_queue: Queue = Queue()
  10. _event_thread: Optional[Thread] = None
  11. _pid: int = -1
  12. @classmethod
  13. def _run(cls):
  14. loop = switch_to_uvloop()
  15. async def receive_tasks():
  16. while True:
  17. cor, future = cls._task_queue.get()
  18. try:
  19. result = await cor
  20. except Exception as e:
  21. future.set_exception(e)
  22. continue
  23. if not future.cancelled():
  24. future.set_result(result)
  25. loop.run_until_complete(receive_tasks())
  26. @classmethod
  27. def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
  28. if cls._event_thread is None or cls._pid != os.getpid():
  29. cls._pid = os.getpid()
  30. cls._event_thread = Thread(target=cls._run, daemon=True)
  31. cls._event_thread.start()
  32. future = Future()
  33. cls._task_queue.put((coro, future))
  34. if return_future:
  35. return future
  36. result = future.result()
  37. return result