
Fix "unable to open shared memory" while using MPFuture (#517)

Currently, one may sometimes get the "unable to open shared memory" error (see the screenshot) while using `hivemind.MPFuture`. Interestingly, the smaller `HIVEMIND_SHM_BUFFER_SIZE` is, the more often the error occurs (e.g., in Petals, it occurs right after starting the server if `HIVEMIND_SHM_BUFFER_SIZE=2`).

It turns out that this happens when the origin process garbage-collects all MPFuture instances that share the same shmem buffer: the underlying buffer is then freed, and target processes can no longer reconnect to it when unpickling their copies of those MPFuture instances.
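
For reference, here is a minimal standalone sketch (not part of this PR) of that failure mode. It assumes a Linux host with PyTorch installed and explicitly selects the `file_system` sharing strategy so that dropping the origin's last reference frees the shared buffer before the child process can open it; the helper names are purely illustrative:

```python
import multiprocessing as mp
from multiprocessing.reduction import ForkingPickler

import torch
import torch.multiprocessing


def unpickle_in_child(payload: bytes, queue):
    # Rebuilding the tensor reopens the shared memory segment created by the origin
    # process; if the origin has already freed it, torch raises
    # RuntimeError("unable to open shared memory object ...").
    try:
        tensor = ForkingPickler.loads(payload)
        queue.put(("ok", int(tensor[0])))
    except RuntimeError as e:
        queue.put(("error", str(e)))


if __name__ == "__main__":
    # file_system strategy makes the repro deterministic: once the origin's storage
    # is deallocated, the named shmem file disappears and reopening it fails
    torch.multiprocessing.set_sharing_strategy("file_system")
    mp.set_start_method("spawn", force=True)

    # Stand-in for the shmem buffer slice that backs an MPFuture's state code
    state = torch.zeros(1, dtype=torch.uint8).share_memory_()
    payload = ForkingPickler.dumps(state).tobytes()

    del state  # origin drops its last reference -> the shared buffer is freed

    queue = mp.Queue()
    child = mp.Process(target=unpickle_in_child, args=(payload, queue))
    child.start()
    print(queue.get(timeout=60))  # typically ("error", "unable to open shared memory object ...")
    child.join()
```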

This PR fixes this important issue.
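
Concretely, `__getstate__` now serializes the shared state tensor with `ForkingPickler.dumps(...)` up front, so the attempt to reattach to the buffer happens inside `__setstate__`, where the resulting `RuntimeError` can be caught. In that case, the state is replaced with a local, non-shared `PENDING` tensor, which is safe because the origin process no longer cares about this future's state (see the diff below).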

(cherry picked from commit 94c985d2dc7a79a091e46c755e9f2f4469b164c7)
Alexander Borzunov, 2 years ago
commit ad8063c254
1 file changed, 10 additions and 3 deletions
  hivemind/utils/mpfuture.py

+ 10 - 3
hivemind/utils/mpfuture.py

@@ -3,12 +3,12 @@ from __future__ import annotations
 import asyncio
 import concurrent.futures._base as base
 import multiprocessing as mp
-import multiprocessing.connection
 import os
 import threading
 import uuid
 from contextlib import nullcontext
 from enum import Enum, auto
+from multiprocessing.reduction import ForkingPickler
 from typing import Any, Callable, Dict, Generic, Optional, TypeVar
 from weakref import ref
 
@@ -303,7 +303,7 @@ class MPFuture(base.Future, Generic[ResultType]):
     def __getstate__(self):
         return dict(
             _sender_pipe=self._sender_pipe,
-            _shared_state_code=self._shared_state_code,
+            _shared_state_code=ForkingPickler.dumps(self._shared_state_code).tobytes(),
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -313,7 +313,14 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     def __setstate__(self, state):
         self._sender_pipe = state["_sender_pipe"]
-        self._shared_state_code = state["_shared_state_code"]
+        try:
+            self._shared_state_code = ForkingPickler.loads(state["_shared_state_code"])
+        except RuntimeError:
+            # If the origin process garbage-collects all instances of MPFuture using the same shmem buffer,
+            # the underlying buffer is freed, and we will get RuntimeError ("unable to open shared memory object")
+            # here since it is not possible to connect to this buffer anymore. To address this, we just replace
+            # the buffer with a non-shared tensor since the origin process doesn't care about our state anymore.
+            self._shared_state_code = torch.tensor([ALL_STATES.index(base.PENDING)], dtype=torch.uint8)
         self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]