|
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
|
import asyncio
|
|
|
import concurrent.futures._base as base
|
|
|
import multiprocessing as mp
|
|
|
-import multiprocessing.connection
|
|
|
import os
|
|
|
import threading
|
|
|
import uuid
|
|
@@ -131,13 +130,13 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
|
|
|
self._set_event_threadsafe()
|
|
|
|
|
|
- def _set_event_threadsafe(self):
|
|
|
+ def _set_event_threadsafe(self) -> None:
|
|
|
try:
|
|
|
running_loop = asyncio.get_running_loop()
|
|
|
except RuntimeError:
|
|
|
running_loop = None
|
|
|
|
|
|
- async def _event_setter():
|
|
|
+ async def _event_setter() -> None:
|
|
|
self._aio_event.set()
|
|
|
|
|
|
if self._loop.is_closed():
|
|
@@ -150,7 +149,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
self._loop.run_until_complete(_event_setter())
|
|
|
|
|
|
@classmethod
|
|
|
- def _initialize_mpfuture_backend(cls):
|
|
|
+ def _initialize_mpfuture_backend(cls) -> None:
|
|
|
pid = os.getpid()
|
|
|
logger.debug(f"Initializing MPFuture backend for pid {pid}")
|
|
|
|
|
@@ -162,7 +161,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
cls._pipe_waiter_thread.start()
|
|
|
|
|
|
@staticmethod
|
|
|
- def reset_backend():
|
|
|
+ def reset_backend() -> None:
|
|
|
"""Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
|
|
|
MPFuture._active_pid = None
|
|
|
MPFuture._initialization_lock = mp.Lock()
|
|
@@ -200,7 +199,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
except Exception as e:
|
|
|
logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
|
|
|
|
|
|
- def _send_update(self, update_type: UpdateType, payload: Any = None):
|
|
|
+ def _send_update(self, update_type: UpdateType, payload: Any = None) -> None:
|
|
|
"""This method sends result, exception or cancel to the MPFuture origin."""
|
|
|
try:
|
|
|
with MPFuture._update_lock if self._use_lock else nullcontext():
|
|
@@ -208,7 +207,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
except (ConnectionError, BrokenPipeError, EOFError, OSError) as e:
|
|
|
logger.debug(f"No updates were sent: pipe to origin process was broken ({e})", exc_info=True)
|
|
|
|
|
|
- def set_result(self, result: ResultType):
|
|
|
+ def set_result(self, result: ResultType) -> None:
|
|
|
if os.getpid() == self._origin_pid:
|
|
|
super().set_result(result)
|
|
|
MPFuture._active_futures.pop(self._uid, None)
|
|
@@ -218,7 +217,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
self._state_cache[self._state], self._result = base.FINISHED, result
|
|
|
self._send_update(UpdateType.RESULT, result)
|
|
|
|
|
|
- def set_exception(self, exception: Optional[BaseException]):
|
|
|
+ def set_exception(self, exception: Optional[BaseException]) -> None:
|
|
|
if os.getpid() == self._origin_pid:
|
|
|
super().set_exception(exception)
|
|
|
MPFuture._active_futures.pop(self._uid, None)
|
|
@@ -239,7 +238,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
self._send_update(UpdateType.CANCEL)
|
|
|
return True
|
|
|
|
|
|
- def set_running_or_notify_cancel(self):
|
|
|
+ def set_running_or_notify_cancel(self) -> bool:
|
|
|
if self._state == base.PENDING:
|
|
|
self._state = base.RUNNING
|
|
|
return True
|
|
@@ -274,18 +273,18 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
def done(self) -> bool:
|
|
|
return self._state in TERMINAL_STATES
|
|
|
|
|
|
- def running(self):
|
|
|
+ def running(self) -> bool:
|
|
|
return self._state == base.RUNNING
|
|
|
|
|
|
- def cancelled(self):
|
|
|
+ def cancelled(self) -> bool:
|
|
|
return self._state == base.CANCELLED
|
|
|
|
|
|
- def add_done_callback(self, callback: Callable[[MPFuture], None]):
|
|
|
+ def add_done_callback(self, callback: Callable[[MPFuture], None]) -> None:
|
|
|
if os.getpid() != self._origin_pid:
|
|
|
raise RuntimeError("Only the process that created MPFuture can set callbacks")
|
|
|
return super().add_done_callback(callback)
|
|
|
|
|
|
- def __await__(self):
|
|
|
+ def __await__(self) -> Any:
|
|
|
if not self._aio_event:
|
|
|
raise RuntimeError("Can't await: MPFuture was created with no event loop")
|
|
|
yield from self._aio_event.wait().__await__()
|
|
@@ -294,13 +293,13 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
except base.CancelledError:
|
|
|
raise asyncio.CancelledError()
|
|
|
|
|
|
- def __del__(self):
|
|
|
+ def __del__(self) -> None:
|
|
|
if getattr(self, "_origin_pid", None) == os.getpid():
|
|
|
MPFuture._active_futures.pop(self._uid, None)
|
|
|
if getattr(self, "_aio_event", None):
|
|
|
self._aio_event.set()
|
|
|
|
|
|
- def __getstate__(self):
|
|
|
+ def __getstate__(self) -> Dict[str, Any]:
|
|
|
return dict(
|
|
|
_sender_pipe=self._sender_pipe,
|
|
|
_shared_state_code=self._shared_state_code,
|
|
@@ -311,7 +310,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
_exception=self._exception,
|
|
|
)
|
|
|
|
|
|
- def __setstate__(self, state):
|
|
|
+ def __setstate__(self, state: Dict[str, Any]) -> None:
|
|
|
self._sender_pipe = state["_sender_pipe"]
|
|
|
self._shared_state_code = state["_shared_state_code"]
|
|
|
self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
|