Ver código fonte

Fix deadlocks in MemoryCache (#396)

- Fix deadlocks in MemoryCache
- Set default --alloc_timeout to 1 until the MemoryCache update
Alexander Borzunov 2 anos atrás
pai
commit
6e4ebb94d2
2 arquivos alterados com 20 adições e 27 exclusões
  1. 1 1
      src/petals/cli/run_server.py
  2. 19 26
      src/petals/server/memory_cache.py

+ 1 - 1
src/petals/cli/run_server.py

@@ -94,7 +94,7 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--alloc_timeout', type=float, default=5,
+    parser.add_argument('--alloc_timeout', type=float, default=1,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
     parser.add_argument('--revision', type=str, default=None,

+ 19 - 26
src/petals/server/memory_cache.py

@@ -90,7 +90,7 @@ class MemoryCache:
             logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
-            await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+            self._free(max_alloc_size, alloc_task)
 
     @staticmethod
     def get_allocation_size(*descriptors: TensorDescriptor) -> int:
@@ -111,25 +111,19 @@ class MemoryCache:
         async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
             if self.current_size_bytes + alloc_size > self.max_size_bytes:
                 await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
-            async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+            with self._lock_metadata:
                 handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                 self.current_size_bytes += alloc_size
                 self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
                 self._pipe_send.send((handles, descriptors))
                 return handles
 
-    async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
-        """
-        This method should be called inside asyncio.shield() because:
-            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
-            - _schedule_free() must finish freeing memory even in case of cancellation
-        """
-
+    def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
 
-        async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+        with self._lock_metadata:
             self._pipe_send.send((handles, None))  # signal runtime to free these handles
             self.current_size_bytes -= alloc_size
         self._memory_freed_event.set()
@@ -160,22 +154,21 @@ class MemoryCache:
         assert os.getpid() == self.runtime_pid
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
 
-        with self._lock_metadata:
-            # read creation/deletion requests from connection handlers
-            while self._pipe_recv.poll():
-                recv_handles, recv_data = self._pipe_recv.recv()
-                if recv_data is not None:  # create new tensors
-                    assert len(recv_handles) == len(recv_data)
-                    for handle, descr in zip(recv_handles, recv_data):
-                        self._allocated_tensors[handle] = descr.make_zeros()
-                        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
-                else:  # delete tensors by handle
-                    for handle in recv_handles:
-                        if handle not in self._allocated_tensors:
-                            logger.warning(
-                                f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
-                            )
-                        self._allocated_tensors.pop(handle, None)
+        # read creation/deletion requests from connection handlers
+        while self._pipe_recv.poll():
+            recv_handles, recv_data = self._pipe_recv.recv()
+            if recv_data is not None:  # create new tensors
+                assert len(recv_handles) == len(recv_data)
+                for handle, descr in zip(recv_handles, recv_data):
+                    self._allocated_tensors[handle] = descr.make_zeros()
+                    assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+            else:  # delete tensors by handle
+                for handle in recv_handles:
+                    if handle not in self._allocated_tensors:
+                        logger.warning(
+                            f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
+                        )
+                    self._allocated_tensors.pop(handle, None)
         yield tuple(self._allocated_tensors[handle] for handle in handles)