Pārlūkot izejas kodu

Shield alloc & free from cancellation (#163)

A handler's RPC code may be cancelled due to a request timeout or a client closing the connection. Before this PR:

- If `.cancel()` happens while waiting for `hivemind.utils.enter_asynchronously()`, the lock will never be released.
- If `.cancel()` happens while doing that before freeing memory, the memory will never be freed.

This PR fixes it by deferring the cancellation with [asyncio.shield()](https://docs.python.org/3/library/asyncio-task.html#asyncio.shield). Now, the cancellation will happen only when all locks are released and alloc/free has completed.
Alexander Borzunov 2 gadi atpakaļ
vecāks
revīzija
9997ada3bb

+ 26 - 8
src/petals/server/handler.py

@@ -1,6 +1,6 @@
 import asyncio
 import contextlib
-from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 from async_timeout import timeout
@@ -93,7 +93,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
 
         async with timeout(self.session_timeout):
-            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            try:
+                request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            except asyncio.TimeoutError:
+                self._log_request("rpc_inference.open", None, context, warning="timed out")
+                return
+
             requested_uids = self._check_uids(request.uid)
             self._log_request("rpc_inference.open", requested_uids, context)
             try:
@@ -193,7 +198,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                         # prepare for next step
                         prefix_length += hidden_states.shape[1]
-                        request = await asyncio.wait_for(anext(requests), self.step_timeout)
+                        try:
+                            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+                        except asyncio.TimeoutError:
+                            self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                            return
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
@@ -369,14 +378,23 @@ class TransformerConnectionHandler(ConnectionHandler):
             logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
             yield handle
 
-    def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
-        friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
-        friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
-        friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+    def _log_request(
+        self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None
+    ) -> None:
+        if uids is not None:
+            friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
+            friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
+            friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+        else:
+            friendly_uids = "n/a"
 
         friendly_remote_id = "..." + str(context.remote_id)[-6:]
 
-        logger.info(f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})")
+        message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})"
+        if warning is None:
+            logger.info(message)
+        else:
+            logger.warning(f"{message}: {warning}")
 
 
 async def _rpc_forward(

+ 40 - 20
src/petals/server/memory_cache.py

@@ -16,6 +16,8 @@ import hivemind
 import torch
 from hivemind.utils import TensorDescriptor, get_logger
 
+from petals.utils.asyncio import shield_and_wait
+
 logger = get_logger(__file__)
 
 Handle = int
@@ -66,28 +68,46 @@ class MemoryCache:
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
         assert descr.device is None and descr
-        allocated_handle = None
-        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
-        loop = asyncio.get_event_loop()
+
+        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
         try:
-            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    await loop.run_in_executor(
-                        None, self._wait_until_available, allocated_size_bytes, self.alloc_timeout
-                    )
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    allocated_handle = int(self.handle_counter)
-                    self.current_size_bytes += allocated_size_bytes
-                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                    self._pipe_send.send((allocated_handle, descr))
-
-            yield allocated_handle
+            yield await shield_and_wait(alloc_task)
         finally:
-            if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
-                    self.current_size_bytes -= allocated_size_bytes
-                self._memory_freed_event.set()
+            await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
+
+    async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+        """
+
+        loop = asyncio.get_event_loop()
+        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
+            if self.current_size_bytes + alloc_size > self.max_size_bytes:
+                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
+            async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                handle = int(self.handle_counter)
+                self.current_size_bytes += alloc_size
+                self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                self._pipe_send.send((handle, descr))
+                return handle
+
+    async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+            - _schedule_free() must finish freeing memory even in case of cancellation
+        """
+
+        if alloc_task.exception() is not None:
+            return
+        handle = alloc_task.result()
+
+        async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+            self._pipe_send.send((handle, None))  # signal runtime to free that handle
+            self.current_size_bytes -= alloc_size
+        self._memory_freed_event.set()
 
     def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):
         # note: this function should only be called inside _lock_acquire_memory!

+ 21 - 0
src/petals/utils/asyncio.py

@@ -0,0 +1,21 @@
+import asyncio
+
+
+async def shield_and_wait(task):
+    """
+    Works like asyncio.shield(), but waits for the task to finish before raising CancelledError to the caller.
+    """
+
+    if not isinstance(task, asyncio.Task):
+        task = asyncio.create_task(task)
+
+    cancel_exc = None
+    while True:
+        try:
+            result = await asyncio.shield(task)
+            break
+        except asyncio.CancelledError as e:
+            cancel_exc = e
+    if cancel_exc is not None:
+        raise cancel_exc
+    return result