Browse Source

Fix exception in MPFuture.__del__() (#555)

This PR addresses the bug reported in #552 - or, at least, it should, since we cannot reproduce the problem locally.

(cherry picked from commit 8c98caa1bec4664902644fc437c1ab02575e407f)
justheuristic 2 năm trước cách đây
mục cha
commit
c0ffab3a50
1 tập tin đã thay đổi với 19 bổ sung15 xóa
  1. 19 15
      hivemind/utils/mpfuture.py

+ 19 - 15
hivemind/utils/mpfuture.py

@@ -95,6 +95,8 @@ class MPFuture(base.Future, Generic[ResultType]):
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
     def __init__(self, *, use_lock: bool = True):
+        self._maybe_initialize_mpfuture_backend()
+
         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
         self._shared_state_code = SharedBytes.next()
         self._state_cache: Dict[State, State] = {}
@@ -105,11 +107,6 @@ class MPFuture(base.Future, Generic[ResultType]):
         self._state, self._result, self._exception = base.PENDING, None, None
         self._use_lock = use_lock
 
-        if self._origin_pid != MPFuture._active_pid:
-            with MPFuture._initialization_lock:
-                if self._origin_pid != MPFuture._active_pid:
-                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
-                    self._initialize_mpfuture_backend()
         assert self._uid not in MPFuture._active_futures
         MPFuture._active_futures[self._uid] = ref(self)
         self._sender_pipe = MPFuture._global_sender_pipe
@@ -151,16 +148,23 @@ class MPFuture(base.Future, Generic[ResultType]):
             self._loop.run_until_complete(_event_setter())
 
     @classmethod
-    def _initialize_mpfuture_backend(cls):
+    def _maybe_initialize_mpfuture_backend(cls):
         pid = os.getpid()
-        logger.debug(f"Initializing MPFuture backend for pid {pid}")
-
-        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
-        cls._active_pid, cls._active_futures = pid, {}
-        cls._pipe_waiter_thread = threading.Thread(
-            target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
-        )
-        cls._pipe_waiter_thread.start()
+        if pid != MPFuture._active_pid:
+            with MPFuture._initialization_lock:
+                if pid != MPFuture._active_pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    logger.debug(f"Initializing MPFuture backend for pid {pid}")
+
+                    receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
+                    cls._active_pid, cls._active_futures = pid, {}
+                    cls._pipe_waiter_thread = threading.Thread(
+                        target=cls._process_updates_in_background,
+                        args=[receiver_pipe],
+                        name=f"{__name__}.BACKEND",
+                        daemon=True,
+                    )
+                    cls._pipe_waiter_thread.start()
 
     @staticmethod
     def reset_backend():
@@ -296,7 +300,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             raise asyncio.CancelledError()
 
     def __del__(self):
-        if getattr(self, "_origin_pid", None) == os.getpid():
+        if getattr(self, "_origin_pid", None) == os.getpid() and MPFuture._active_futures is not None:
             MPFuture._active_futures.pop(self._uid, None)
         if getattr(self, "_aio_event", None):
             self._aio_event.set()