5
0
Эх сурвалжийг харах

Fix handler memory leak, get rid of mp.Manager (#373)

This PR removes the memory leak from somewhere within handler.py that has something to do with mp.SyncManager.
justheuristic 2 жил өмнө
parent
commit
5a8de2f1f8

+ 117 - 72
src/petals/server/handler.py

@@ -2,9 +2,9 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
-import multiprocessing.managers
+import multiprocessing as mp
 import sys
-from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 from itertools import chain
 from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
@@ -42,20 +42,15 @@ logger = get_logger(__name__)
 # Fix pickling protobufs, see https://stackoverflow.com/a/74873028
 sys.modules["runtime_pb2"] = runtime_pb2
 
-# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256
 
-_OriginalAutoProxy = multiprocessing.managers.AutoProxy
-
-
-def patched_autoproxy(*args, manager_owned=True, **kwargs):
-    # Calling original AutoProxy without the unwanted key argument
-    return _OriginalAutoProxy(*args, **kwargs)
-
-
-multiprocessing.managers.AutoProxy = patched_autoproxy
+CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
 
 
-CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
+class Event(Enum):
+    NEW_SESSION = 0
+    END_SESSION = 1
+    PUSH = 2
+    SHUTDOWN = 3
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -70,8 +65,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         *,
         adapters: Optional[Sequence[str]],
         dht_prefix: str,
-        push_manager: multiprocessing.managers.SyncManager,
-        session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
+        handler_event_queues: Sequence[mp.Queue],
+        handler_index: int,
         inference_max_length: int,
         request_timeout: float,
         session_timeout: float,
@@ -83,18 +78,28 @@ class TransformerConnectionHandler(ConnectionHandler):
             assert isinstance(module_backend, TransformerBackend)
         self.dht_prefix = dht_prefix
         self.adapters = adapters
-        self._push_manager = push_manager
-        self._session_queues = session_queues
-        self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
+        self._handler_event_queues = handler_event_queues
+        self._handler_index = handler_index
+        self._own_event_queue = handler_event_queues[handler_index]
+        self._listener_task: Optional[asyncio.Task] = None
+        self._session_queues: Dict[str, asyncio.Queue] = {}
+        self._session_handlers: Dict[str, int] = {}
 
         self.inference_max_length = inference_max_length
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
         self._prioritizer = task_prioritizer
 
+    async def add_p2p_handlers(self, *args, **kwargs) -> None:
+        if self._listener_task is None:
+            # Start listening to our own event queue before we accept any requests
+            self._listener_task = asyncio.create_task(self._listen_to_event_queue())
+        await super().add_p2p_handlers(*args, **kwargs)
+
     def shutdown(self):
         if self.is_alive():
             self._outer_pipe.send("_shutdown")
+            self._own_event_queue.put((Event.SHUTDOWN, None, None))
             self.join(self.shutdown_timeout)
             if self.is_alive():
                 logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
@@ -129,7 +134,6 @@ class TransformerConnectionHandler(ConnectionHandler):
         context: P2PContext,
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
-
         async with timeout(self.session_timeout):
             try:
                 request = await asyncio.wait_for(anext(requests), self.step_timeout)
@@ -146,7 +150,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                 active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
-
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
                 assert isinstance(
@@ -235,6 +238,56 @@ class TransformerConnectionHandler(ConnectionHandler):
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
+    @contextlib.contextmanager
+    def _managed_session(self, session_id: str):
+        assert session_id not in self._session_queues, f"session id {session_id} is not unique"
+        try:
+            self._session_queues[session_id] = asyncio.Queue()
+            self._session_handlers[session_id] = self._handler_index
+            for other_index, other_queue in enumerate(self._handler_event_queues):
+                if other_index != self._handler_index:
+                    other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index))
+            yield
+        finally:
+            self._session_queues.pop(session_id).put_nowait(None)  # put None so that the get task will not hang
+            del self._session_handlers[session_id]
+            for other_index, other_queue in enumerate(self._handler_event_queues):
+                if other_index != self._handler_index:
+                    other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index))
+
+    def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest):
+        handler_index = self._session_handlers.get(session_id)
+        if handler_index is None:
+            logger.debug(f"Ignored rpc_push to unknown session ID: {session_id}")
+        elif handler_index == self._handler_index:
+            self._session_queues[session_id].put_nowait(request)
+        else:
+            self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request))
+
+    async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]:
+        assert self._session_handlers[session_id] == self._handler_index, "session belongs to another handler"
+        return await self._session_queues[session_id].get()
+
+    async def _listen_to_event_queue(self):
+        loop = asyncio.get_event_loop()
+        while True:
+            try:
+                event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get)
+                if event == Event.SHUTDOWN:
+                    break
+                elif event == Event.NEW_SESSION:
+                    self._session_handlers[session_id] = payload  # index of the handler that owns that session
+                elif event == Event.END_SESSION:
+                    self._session_handlers.pop(session_id, None)
+                elif event == Event.PUSH:
+                    maybe_session_queue = self._session_queues.get(session_id)
+                    if maybe_session_queue is not None:
+                        maybe_session_queue.put_nowait(payload)
+                else:
+                    raise RuntimeError(f"Unexpected event: {event}")
+            except Exception as e:
+                logger.exception(e)
+
     async def _iterate_inference_steps(
         self,
         first_request: runtime_pb2.ExpertRequest,
@@ -243,67 +296,60 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids: Sequence[str],
         context: P2PContext,
     ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:
-        loop = asyncio.get_event_loop()
-        if session_id is not None:
-            push_queue = self._push_manager.Queue()
-            self._session_queues[session_id] = push_queue
-
         processed_step_ids = set()
         n_pushes = n_late_pushes = 0
         request = first_request
         anext_task = get_push_task = None
         try:
-            while request.tensors:  # iterate while user is willing to supply tensors
-                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-                step_id = metadata.get("step_id")
-
-                pushed = metadata.get("pushed")
-                if pushed:
-                    n_pushes += 1
-
-                if step_id is None or step_id not in processed_step_ids:
-                    yield request, metadata
-                    if step_id is not None:
-                        processed_step_ids.add(step_id)
-                elif pushed:
-                    n_late_pushes += 1
-                    self._log_request(
-                        "rpc_inference.push",
-                        requested_uids,
-                        context,
-                        warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
+            with self._managed_session(session_id) if session_id is not None else contextlib.nullcontext():
+                while request.tensors:  # iterate while user is willing to supply tensors
+                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                    step_id = metadata.get("step_id")
+
+                    pushed = metadata.get("pushed")
+                    if pushed:
+                        n_pushes += 1
+                        self._log_request("rpc_inference.push", requested_uids, context, debug=f"session received push")
+
+                    if step_id is None or step_id not in processed_step_ids:
+                        yield request, metadata
+                        if step_id is not None:
+                            processed_step_ids.add(step_id)
+                    elif pushed:
+                        n_late_pushes += 1
+                        self._log_request(
+                            "rpc_inference.push",
+                            requested_uids,
+                            context,
+                            warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
+                        )
+
+                    # Wait for the next request, coming either from the `requests` iterator or `push_queue`
+                    if anext_task is None:
+                        anext_task = asyncio.create_task(anext(requests))
+                    if get_push_task is None:
+                        if session_id is not None:
+                            get_push_task = asyncio.create_task(self._get_from_session_queue(session_id))
+                        else:
+                            get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
+                    done, _ = await asyncio.wait(
+                        [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
                     )
 
-                # Wait for the next request, coming either from the `requests` iterator or `push_queue`
-                if anext_task is None:
-                    anext_task = asyncio.create_task(anext(requests))
-                if get_push_task is None:
-                    if session_id is not None:
-                        get_push_task = loop.run_in_executor(self._executor, push_queue.get)
+                    if anext_task in done:
+                        request = await anext_task
+                        anext_task = None
+                    elif get_push_task in done:
+                        request = await get_push_task
+                        get_push_task = None
                     else:
-                        get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
-                done, _ = await asyncio.wait(
-                    [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
-                )
-
-                if anext_task in done:
-                    request = await anext_task
-                    anext_task = None
-                elif get_push_task in done:
-                    request = await get_push_task
-                    get_push_task = None
-                else:
-                    self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
-                    anext_task.cancel()
-                    get_push_task.cancel()
-                    return
+                        self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                        anext_task.cancel()
+                        get_push_task.cancel()
+                        return
         except:
             logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
             raise
-        finally:
-            if session_id is not None:
-                push_queue.put(None)  # Stop thread for get_push_task
-                del self._session_queues[session_id]
 
     async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         """Directly push activation tensors from one server to another"""
@@ -312,8 +358,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         metadata = MSGPackSerializer.loads(request.metadata)
         session_id = metadata["session_id"]
         self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}")
-
-        self._session_queues[session_id].put(request)
+        self._put_into_session_queue(session_id, request)
         return runtime_pb2.ExpertResponse()
 
     async def _push_outputs(

+ 4 - 7
src/petals/server/server.py

@@ -528,23 +528,21 @@ class ModuleContainer(threading.Thread):
         self.dht, self.module_backends = dht, module_backends
         self.server_info, self.update_period, self.expiration = server_info, update_period, expiration
 
-        self.push_manager = mp.Manager()
-        self.push_manager.__enter__()
-        session_queues = self.push_manager.dict()
+        handler_event_queues = [mp.Queue() for _ in range(num_handlers)]
         self.conn_handlers = [
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
                 adapters=server_info.adapters,
                 dht_prefix=dht_prefix,
-                push_manager=self.push_manager,
-                session_queues=session_queues,
+                handler_event_queues=handler_event_queues,
+                handler_index=i,
                 inference_max_length=inference_max_length,
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
                 step_timeout=step_timeout,
             )
-            for _ in range(num_handlers)
+            for i in range(num_handlers)
         ]
 
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
@@ -607,7 +605,6 @@ class ModuleContainer(threading.Thread):
         logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
-        self.push_manager.__exit__(None, None, None)
 
         logger.debug(f"Shutting down pools")
         for pool in self.runtime.pools: