Browse Source

Fix "could not unlink the shared memory file" during rebalancing (#135)

Alexander Borzunov 2 years ago
parent
commit
77a00e17f0
1 changed file with 9 additions and 2 deletions
  1. 9 2
      src/petals/server/task_pool.py

+ 9 - 2
src/petals/server/task_pool.py

@@ -2,13 +2,15 @@ import ctypes
 import multiprocessing as mp
 import multiprocessing as mp
 import threading
 import threading
 import time
 import time
+from concurrent.futures._base import PENDING
 from dataclasses import dataclass, field
 from dataclasses import dataclass, field
 from queue import PriorityQueue
 from queue import PriorityQueue
 from typing import Any, List, Optional, Sequence, Tuple
 from typing import Any, List, Optional, Sequence, Tuple
 
 
 import torch
 import torch
-from hivemind import MPFuture, get_logger, use_hivemind_log_handler
+from hivemind import get_logger, use_hivemind_log_handler
 from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.moe.server.task_pool import TaskPoolBase
+from hivemind.utils.mpfuture import ALL_STATES, MPFuture
 
 
 use_hivemind_log_handler("in_root_logger")
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 logger = get_logger(__file__)
@@ -102,7 +104,12 @@ class PrioritizedTaskPool(TaskPoolBase):
 
 
     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         """Add task to this pool's queue, return Future for its output"""
-        task = Task(priority, time.monotonic(), MPFuture(), args)
+        future = MPFuture()
+        # Remove shmem from MPFuture. This disables the .cancel() feature but
+        # saves the server from "could not unlink the shared memory file" crashes during rebalancing
+        future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8)
+
+        task = Task(priority, time.monotonic(), future, args)
         if self.get_task_size(task) > self.max_batch_size:
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)
             task.future.set_exception(exc)