Fix "unable to open shared memory" while using MPFuture (#517)

Currently, one may sometimes get the "unable to open shared memory" error (see the screenshot) while using `hivemind.MPFuture`. Interestingly, the smaller `HIVEMIND_SHM_BUFFER_SIZE` is, the more often the error occurs (e.g., in Petals, it occurs right after starting the server if `HIVEMIND_SHM_BUFFER_SIZE=2`).

It turns out that this happens when the origin process garbage-collects all MPFuture instances sharing the same shmem buffer: the underlying buffer is freed, and target processes can no longer reconnect to it when unpickling their MPFuture instances.

This PR fixes this important issue.
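The root cause can be reproduced with the standard library alone (this is an illustrative analogue, not hivemind's actual code, which uses torch shared-memory tensors): once the creating process closes and unlinks a shared memory block, any attempt to reattach to it by name fails.

```python
from multiprocessing import shared_memory

# The origin process creates a shared memory block, similar to how
# MPFuture instances share a shmem state buffer.
shm = shared_memory.SharedMemory(create=True, size=1)
name = shm.name

# The origin process garbage-collects its futures: the buffer is
# closed and unlinked, so the OS-level object disappears.
shm.close()
shm.unlink()

# A target process unpickling an MPFuture then tries to reattach by name
# and fails -- the stdlib raises FileNotFoundError here, analogous to the
# RuntimeError ("unable to open shared memory object") raised by torch.
try:
    shared_memory.SharedMemory(name=name)
    reattached = True
except FileNotFoundError:
    reattached = False
```

The fix in this PR handles exactly this situation on the unpickling side instead of letting the error propagate.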
Alexander Borzunov, 2 years ago
Parent
Commit
94c985d2dc
1 changed file with 10 additions and 3 deletions:
    hivemind/utils/mpfuture.py

hivemind/utils/mpfuture.py (+10 −3)

@@ -3,12 +3,12 @@ from __future__ import annotations
 import asyncio
 import concurrent.futures._base as base
 import multiprocessing as mp
-import multiprocessing.connection
 import os
 import threading
 import uuid
 from contextlib import nullcontext
 from enum import Enum, auto
+from multiprocessing.reduction import ForkingPickler
 from typing import Any, Callable, Dict, Generic, Optional, TypeVar
 from weakref import ref
 
@@ -303,7 +303,7 @@ class MPFuture(base.Future, Generic[ResultType]):
     def __getstate__(self):
         return dict(
             _sender_pipe=self._sender_pipe,
-            _shared_state_code=self._shared_state_code,
+            _shared_state_code=ForkingPickler.dumps(self._shared_state_code).tobytes(),
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -313,7 +313,14 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     def __setstate__(self, state):
         self._sender_pipe = state["_sender_pipe"]
-        self._shared_state_code = state["_shared_state_code"]
+        try:
+            self._shared_state_code = ForkingPickler.loads(state["_shared_state_code"])
+        except RuntimeError:
+            # If the origin process garbage-collects all instances of MPFuture using the same shmem buffer,
+            # the underlying buffer is freed, and we will get RuntimeError ("unable to open shared memory object")
+            # here since it is not possible to connect to this buffer anymore. To address this, we just replace
+            # the buffer with a non-shared tensor since the origin process doesn't care about our state anymore.
+            self._shared_state_code = torch.tensor([ALL_STATES.index(base.PENDING)], dtype=torch.uint8)
         self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]
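The pattern used by this diff, serializing the fragile shared resource as a separate blob in `__getstate__` and falling back to a local default when deserialization fails in `__setstate__`, can be sketched in isolation as follows (the `Handle` class and its `state` attribute are hypothetical names for illustration; hivemind's real code uses `ForkingPickler` and a shared uint8 tensor):

```python
import pickle


class Handle:
    """Sketch of the MPFuture fallback: an object holding a process-shared
    resource falls back to a local placeholder if the resource is gone
    by the time the receiving process unpickles it."""

    def __init__(self, state=b"PENDING"):
        self.state = state

    def __getstate__(self):
        # Serialize the shared state as a separate nested blob, as the PR
        # does with ForkingPickler.dumps(...).tobytes().
        return {"state_blob": pickle.dumps(self.state)}

    def __setstate__(self, d):
        try:
            self.state = pickle.loads(d["state_blob"])
        except RuntimeError:
            # The shared buffer was freed by the origin process; since the
            # origin no longer cares about our state, a local default is safe.
            self.state = b"PENDING"
```

Deferring deserialization of the fragile part to `__setstate__` is what makes the fallback possible: if the shared buffer were unpickled eagerly as part of the outer state dict, the error could not be caught per-field.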