Browse Source

Merge branch 'master' into simplify-running-loop

Alexander Borzunov 2 years ago
parent
commit
6202733f09
1 changed file with 10 additions and 3 deletions
  1. 10 3
      hivemind/utils/mpfuture.py

+ 10 - 3
hivemind/utils/mpfuture.py

@@ -3,12 +3,12 @@ from __future__ import annotations
 import asyncio
 import concurrent.futures._base as base
 import multiprocessing as mp
-import multiprocessing.connection
 import os
 import threading
 import uuid
 from contextlib import nullcontext
 from enum import Enum, auto
+from multiprocessing.reduction import ForkingPickler
 from typing import Any, Callable, Dict, Generic, Optional, TypeVar
 from weakref import ref
 
@@ -303,7 +303,7 @@ class MPFuture(base.Future, Generic[ResultType]):
     def __getstate__(self):
         return dict(
             _sender_pipe=self._sender_pipe,
-            _shared_state_code=self._shared_state_code,
+            _shared_state_code=ForkingPickler.dumps(self._shared_state_code).tobytes(),
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -313,7 +313,14 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     def __setstate__(self, state):
         self._sender_pipe = state["_sender_pipe"]
-        self._shared_state_code = state["_shared_state_code"]
+        try:
+            self._shared_state_code = ForkingPickler.loads(state["_shared_state_code"])
+        except RuntimeError:
+            # If the origin process garbage-collects all instances of MPFuture using the same shmem buffer,
+            # the underlying buffer is freed, and we will get RuntimeError ("unable to open shared memory object")
+            # here since it is not possible to connect to this buffer anymore. To address this, we just replace
+            # the buffer with a non-shared tensor since the origin process doesn't care about our state anymore.
+            self._shared_state_code = torch.tensor([ALL_STATES.index(base.PENDING)], dtype=torch.uint8)
         self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]