
Fix handler memory leak, get rid of mp.Manager (#373)

This PR removes a memory leak originating somewhere within handler.py, apparently tied to mp.SyncManager. Instead of sharing queues through a SyncManager process, each connection handler now owns a plain mp.Queue of events (NEW_SESSION, END_SESSION, PUSH, SHUTDOWN) and forwards pushed requests to the handler that owns the session; a sketch of the scheme follows the changed-files list below.
justheuristic · 2 years ago · commit 5a8de2f1f8
2 changed files with 121 additions and 79 deletions
  1. src/petals/server/handler.py (+117 -72)
  2. src/petals/server/server.py (+4 -7)
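In the new scheme, the server creates one plain mp.Queue per connection handler. A handler that opens an inference session announces its ownership to every other handler; a handler that receives rpc_push for a session it does not own forwards the request to the owner's queue. Below is a minimal, self-contained sketch of that routing pattern (hypothetical names, collapsed into a single process for brevity; the real implementation is in the diff that follows):

import multiprocessing as mp
from enum import Enum


class Event(Enum):
    NEW_SESSION = 0
    END_SESSION = 1
    PUSH = 2
    SHUTDOWN = 3


class HandlerSketch:
    def __init__(self, queues, index):
        self.queues = queues        # one mp.Queue per handler, shared by all handlers
        self.index = index          # position of this handler's own queue
        self.session_handlers = {}  # session_id -> index of the owning handler

    def open_session(self, session_id):
        self.session_handlers[session_id] = self.index
        # Announce ownership of the new session to every other handler
        for i, q in enumerate(self.queues):
            if i != self.index:
                q.put_nowait((Event.NEW_SESSION, session_id, self.index))

    def route_push(self, session_id, request):
        owner = self.session_handlers.get(session_id)
        if owner is None:
            return  # unknown session id: the push is ignored
        if owner == self.index:
            print(f"handler {self.index}: delivering {request!r} locally")
        else:
            # Forward the request to the queue of the handler that owns the session
            self.queues[owner].put_nowait((Event.PUSH, session_id, request))


if __name__ == "__main__":
    queues = [mp.Queue() for _ in range(2)]
    a, b = HandlerSketch(queues, 0), HandlerSketch(queues, 1)
    a.open_session("sess-1")
    b.session_handlers["sess-1"] = 0  # b would learn this from the NEW_SESSION event
    b.route_push("sess-1", "activations")  # forwarded to handler a's queue
    print(queues[0].get())  # -> (<Event.PUSH: 2>, 'sess-1', 'activations')

Since the handlers exchange plain multiprocessing queues instead of proxies served by a SyncManager process, nothing needs to be leased from (or leaked by) a manager.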

+ 117 - 72
src/petals/server/handler.py

@@ -2,9 +2,9 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
-import multiprocessing.managers
+import multiprocessing as mp
 import sys
-from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 from itertools import chain
 from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
@@ -42,20 +42,15 @@ logger = get_logger(__name__)
 
 # Fix pickling protobufs, see https://stackoverflow.com/a/74873028
 sys.modules["runtime_pb2"] = runtime_pb2
 
-# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256
-_OriginalAutoProxy = multiprocessing.managers.AutoProxy
-
-
-def patched_autoproxy(*args, manager_owned=True, **kwargs):
-    # Calling original AutoProxy without the unwanted key argument
-    return _OriginalAutoProxy(*args, **kwargs)
-
-
-multiprocessing.managers.AutoProxy = patched_autoproxy
+CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
 
 
-CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
+class Event(Enum):
+    NEW_SESSION = 0
+    END_SESSION = 1
+    PUSH = 2
+    SHUTDOWN = 3
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -70,8 +65,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         *,
         adapters: Optional[Sequence[str]],
         dht_prefix: str,
-        push_manager: multiprocessing.managers.SyncManager,
-        session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
+        handler_event_queues: Sequence[mp.Queue],
+        handler_index: int,
         inference_max_length: int,
         request_timeout: float,
         session_timeout: float,
@@ -83,18 +78,28 @@ class TransformerConnectionHandler(ConnectionHandler):
             assert isinstance(module_backend, TransformerBackend)
         self.dht_prefix = dht_prefix
         self.adapters = adapters
-        self._push_manager = push_manager
-        self._session_queues = session_queues
-        self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
+        self._handler_event_queues = handler_event_queues
+        self._handler_index = handler_index
+        self._own_event_queue = handler_event_queues[handler_index]
+        self._listener_task: Optional[asyncio.Task] = None
+        self._session_queues: Dict[str, asyncio.Queue] = {}
+        self._session_handlers: Dict[str, int] = {}
 
         self.inference_max_length = inference_max_length
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
         self._prioritizer = task_prioritizer
 
+    async def add_p2p_handlers(self, *args, **kwargs) -> None:
+        if self._listener_task is None:
+            # Start listening to our own event queue before we accept any requests
+            self._listener_task = asyncio.create_task(self._listen_to_event_queue())
+        await super().add_p2p_handlers(*args, **kwargs)
+
     def shutdown(self):
         if self.is_alive():
             self._outer_pipe.send("_shutdown")
+            self._own_event_queue.put((Event.SHUTDOWN, None, None))
             self.join(self.shutdown_timeout)
             if self.is_alive():
                 logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
@@ -129,7 +134,6 @@ class TransformerConnectionHandler(ConnectionHandler):
         context: P2PContext,
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
-
         async with timeout(self.session_timeout):
             try:
                 request = await asyncio.wait_for(anext(requests), self.step_timeout)
@@ -146,7 +150,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                 active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
-
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
                 assert isinstance(
@@ -235,6 +238,56 @@ class TransformerConnectionHandler(ConnectionHandler):
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
+    @contextlib.contextmanager
+    def _managed_session(self, session_id: str):
+        assert session_id not in self._session_queues, f"session id {session_id} is not unique"
+        try:
+            self._session_queues[session_id] = asyncio.Queue()
+            self._session_handlers[session_id] = self._handler_index
+            for other_index, other_queue in enumerate(self._handler_event_queues):
+                if other_index != self._handler_index:
+                    other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index))
+            yield
+        finally:
+            self._session_queues.pop(session_id).put_nowait(None)  # put None so that the get task will not hang
+            del self._session_handlers[session_id]
+            for other_index, other_queue in enumerate(self._handler_event_queues):
+                if other_index != self._handler_index:
+                    other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index))
+
+    def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest):
+        handler_index = self._session_handlers.get(session_id)
+        if handler_index is None:
+            logger.debug(f"Ignored rpc_push to unknown session ID: {session_id}")
+        elif handler_index == self._handler_index:
+            self._session_queues[session_id].put_nowait(request)
+        else:
+            self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request))
+
+    async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]:
+        assert self._session_handlers[session_id] == self._handler_index, "session belongs to another handler"
+        return await self._session_queues[session_id].get()
+
+    async def _listen_to_event_queue(self):
+        loop = asyncio.get_event_loop()
+        while True:
+            try:
+                event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get)
+                if event == Event.SHUTDOWN:
+                    break
+                elif event == Event.NEW_SESSION:
+                    self._session_handlers[session_id] = payload  # index of the handler that owns that session
+                elif event == Event.END_SESSION:
+                    self._session_handlers.pop(session_id, None)
+                elif event == Event.PUSH:
+                    maybe_session_queue = self._session_queues.get(session_id)
+                    if maybe_session_queue is not None:
+                        maybe_session_queue.put_nowait(payload)
+                else:
+                    raise RuntimeError(f"Unexpected event: {event}")
+            except Exception as e:
+                logger.exception(e)
+
     async def _iterate_inference_steps(
         self,
         first_request: runtime_pb2.ExpertRequest,
@@ -243,67 +296,60 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids: Sequence[str],
         context: P2PContext,
     ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:
-        loop = asyncio.get_event_loop()
-        if session_id is not None:
-            push_queue = self._push_manager.Queue()
-            self._session_queues[session_id] = push_queue
-
         processed_step_ids = set()
         n_pushes = n_late_pushes = 0
         request = first_request
         anext_task = get_push_task = None
         try:
-            while request.tensors:  # iterate while user is willing to supply tensors
-                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-                step_id = metadata.get("step_id")
-
-                pushed = metadata.get("pushed")
-                if pushed:
-                    n_pushes += 1
-
-                if step_id is None or step_id not in processed_step_ids:
-                    yield request, metadata
-                    if step_id is not None:
-                        processed_step_ids.add(step_id)
-                elif pushed:
-                    n_late_pushes += 1
-                    self._log_request(
-                        "rpc_inference.push",
-                        requested_uids,
-                        context,
-                        warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
+            with self._managed_session(session_id) if session_id is not None else contextlib.nullcontext():
+                while request.tensors:  # iterate while user is willing to supply tensors
+                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                    step_id = metadata.get("step_id")
+
+                    pushed = metadata.get("pushed")
+                    if pushed:
+                        n_pushes += 1
+                        self._log_request("rpc_inference.push", requested_uids, context, debug=f"session received push")
+
+                    if step_id is None or step_id not in processed_step_ids:
+                        yield request, metadata
+                        if step_id is not None:
+                            processed_step_ids.add(step_id)
+                    elif pushed:
+                        n_late_pushes += 1
+                        self._log_request(
+                            "rpc_inference.push",
+                            requested_uids,
+                            context,
+                            warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
+                        )
+
+                    # Wait for the next request, coming either from the `requests` iterator or `push_queue`
+                    if anext_task is None:
+                        anext_task = asyncio.create_task(anext(requests))
+                    if get_push_task is None:
+                        if session_id is not None:
+                            get_push_task = asyncio.create_task(self._get_from_session_queue(session_id))
+                        else:
+                            get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
+                    done, _ = await asyncio.wait(
+                        [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
                     )
 
-                # Wait for the next request, coming either from the `requests` iterator or `push_queue`
-                if anext_task is None:
-                    anext_task = asyncio.create_task(anext(requests))
-                if get_push_task is None:
-                    if session_id is not None:
-                        get_push_task = loop.run_in_executor(self._executor, push_queue.get)
+                    if anext_task in done:
+                        request = await anext_task
+                        anext_task = None
+                    elif get_push_task in done:
+                        request = await get_push_task
+                        get_push_task = None
                     else:
-                        get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
-                done, _ = await asyncio.wait(
-                    [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
-                )
-
-                if anext_task in done:
-                    request = await anext_task
-                    anext_task = None
-                elif get_push_task in done:
-                    request = await get_push_task
-                    get_push_task = None
-                else:
-                    self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
-                    anext_task.cancel()
-                    get_push_task.cancel()
-                    return
+                        self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                        anext_task.cancel()
+                        get_push_task.cancel()
+                        return
         except:
             logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
             raise
-        finally:
-            if session_id is not None:
-                push_queue.put(None)  # Stop thread for get_push_task
-                del self._session_queues[session_id]
 
     async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         """Directly push activation tensors from one server to another"""
@@ -312,8 +358,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         metadata = MSGPackSerializer.loads(request.metadata)
         session_id = metadata["session_id"]
         self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}")
-
-        self._session_queues[session_id].put(request)
+        self._put_into_session_queue(session_id, request)
         return runtime_pb2.ExpertResponse()
 
     async def _push_outputs(
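One detail worth noting above: mp.Queue.get() is a blocking call, so _listen_to_event_queue awaits it through loop.run_in_executor to avoid stalling the asyncio event loop, and shutdown() wakes the listener with a SHUTDOWN sentinel. A stripped-down sketch of that bridge (hypothetical names, single queue):

import asyncio
import multiprocessing as mp
from enum import Enum


class Event(Enum):
    PUSH = 0
    SHUTDOWN = 1


async def listen(event_queue: mp.Queue) -> None:
    loop = asyncio.get_event_loop()
    while True:
        # mp.Queue.get() blocks, so run it in the default thread pool
        # instead of blocking the event loop
        event, payload = await loop.run_in_executor(None, event_queue.get)
        if event == Event.SHUTDOWN:
            break  # the sentinel unblocks the waiting thread so it can exit
        print("received:", event, payload)


if __name__ == "__main__":
    queue = mp.Queue()
    queue.put((Event.PUSH, "pushed request"))
    queue.put((Event.SHUTDOWN, None))
    asyncio.run(listen(queue))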

+ 4 - 7
src/petals/server/server.py

@@ -528,23 +528,21 @@ class ModuleContainer(threading.Thread):
         self.dht, self.module_backends = dht, module_backends
         self.server_info, self.update_period, self.expiration = server_info, update_period, expiration
 
-        self.push_manager = mp.Manager()
-        self.push_manager.__enter__()
-        session_queues = self.push_manager.dict()
+        handler_event_queues = [mp.Queue() for _ in range(num_handlers)]
         self.conn_handlers = [
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
                 adapters=server_info.adapters,
                 dht_prefix=dht_prefix,
-                push_manager=self.push_manager,
-                session_queues=session_queues,
+                handler_event_queues=handler_event_queues,
+                handler_index=i,
                 inference_max_length=inference_max_length,
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
                 step_timeout=step_timeout,
             )
-            for _ in range(num_handlers)
+            for i in range(num_handlers)
         ]
 
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
@@ -607,7 +605,6 @@ class ModuleContainer(threading.Thread):
         logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
-        self.push_manager.__exit__(None, None, None)
 
         logger.debug(f"Shutting down pools")
        for pool in self.runtime.pools:
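
Net effect on server.py: ModuleContainer no longer spawns a hidden SyncManager server process or keeps proxy objects alive. Each handler receives the full list of queues plus its own index at construction time, and shutdown no longer tears down a manager, since every handler unblocks itself with the SHUTDOWN sentinel shown above.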