Prechádzať zdrojové kódy

Fix race condition in MemoryCache (#487)

Alexander Borzunov pred 2 rokmi
rodič
commit
02fc71eb25
1 zmenený súbor, 7 pridaní a 4 odstránenia
  1. 7 4
      src/petals/server/memory_cache.py

+ 7 - 4
src/petals/server/memory_cache.py

@@ -31,7 +31,7 @@ class MemoryCache:
         self.max_alloc_timeout = max_alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
-        self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
         self.runtime_pid = os.getpid()
@@ -138,7 +138,8 @@ class MemoryCache:
         start_time = time.perf_counter()
         loop = asyncio.get_event_loop()
 
-        self.enqueued_size_bytes += alloc_size
+        with self._enqueued_size.get_lock():
+            self._enqueued_size.value += alloc_size
         allocated = False
         try:
             context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
@@ -155,13 +156,15 @@ class MemoryCache:
                         await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
 
                 allocated = True
-                self.enqueued_size_bytes -= alloc_size
+                with self._enqueued_size.get_lock():
+                    self._enqueued_size.value -= alloc_size
                 yield
         except asyncio.TimeoutError:
             raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
         finally:
             if not allocated:
-                self.enqueued_size_bytes -= alloc_size
+                with self._enqueued_size.get_lock():
+                    self._enqueued_size.value -= alloc_size
 
     def _free(self, alloc_size: int, alloc_task: asyncio.Task):
         if alloc_task.exception() is not None: