|
@@ -3,12 +3,12 @@ from __future__ import annotations
|
|
|
import asyncio
|
|
|
import concurrent.futures._base as base
|
|
|
import multiprocessing as mp
|
|
|
-import multiprocessing.connection
|
|
|
import os
|
|
|
import threading
|
|
|
import uuid
|
|
|
from contextlib import nullcontext
|
|
|
from enum import Enum, auto
|
|
|
+from multiprocessing.reduction import ForkingPickler
|
|
|
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
|
|
from weakref import ref
|
|
|
|
|
@@ -303,7 +303,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
def __getstate__(self):
|
|
|
return dict(
|
|
|
_sender_pipe=self._sender_pipe,
|
|
|
- _shared_state_code=self._shared_state_code,
|
|
|
+ _shared_state_code=ForkingPickler.dumps(self._shared_state_code).tobytes(),
|
|
|
_origin_pid=self._origin_pid,
|
|
|
_uid=self._uid,
|
|
|
_use_lock=self._use_lock,
|
|
@@ -313,7 +313,14 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
|
|
|
def __setstate__(self, state):
|
|
|
self._sender_pipe = state["_sender_pipe"]
|
|
|
- self._shared_state_code = state["_shared_state_code"]
|
|
|
+ try:
|
|
|
+ self._shared_state_code = ForkingPickler.loads(state["_shared_state_code"])
|
|
|
+ except RuntimeError:
|
|
|
+ # If the origin process garbage-collects all instances of MPFuture using the same shmem buffer,
|
|
|
+ # the underlying buffer is freed, and we will get RuntimeError ("unable to open shared memory object")
|
|
|
+ # here since it is not possible to connect to this buffer anymore. To address this, we just replace
|
|
|
+ # the buffer with a non-shared tensor since the origin process doesn't care about our state anymore.
|
|
|
+ self._shared_state_code = torch.tensor([ALL_STATES.index(base.PENDING)], dtype=torch.uint8)
|
|
|
self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
|
|
|
self._result, self._exception = state["_result"], state["_exception"]
|
|
|
self._use_lock = state["_use_lock"]
|