فهرست منبع

DO NOT MERGE UNDER ANY CIRCUMSTANCES

Aleksandr Borzunov 2 سال پیش
والد
کامیت
74c086ea35
1فایلهای تغییر یافته به همراه13 افزوده شده و 9 حذف شده
  1. 13 9
      src/petals/server/handler.py

+ 13 - 9
src/petals/server/handler.py

@@ -129,6 +129,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         context: P2PContext,
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
+        import os, psutil; print(f"handler rpc_inference {os.getpid()} : {psutil.Process().memory_info().rss / 1024 / 1024} mb")
 
         async with timeout(self.session_timeout):
             try:
@@ -304,18 +305,21 @@ class TransformerConnectionHandler(ConnectionHandler):
             if session_id is not None:
                 push_queue.put(None)  # Stop thread for get_push_task
                 del self._session_queues[session_id]
+                print("DELETED SESSION", session_id, flush=True)
 
     async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         """Directly push activation tensors from one server to another"""
-
-        requested_uids = self._check_uids(request.uid)
-        metadata = MSGPackSerializer.loads(request.metadata)
-        session_id = metadata["session_id"]
-        self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}")
-
-        self._session_queues[session_id].put(request)
-        return runtime_pb2.ExpertResponse()
-
+        import os, psutil; print(f"handler rpc_push {os.getpid()} : {psutil.Process().memory_info().rss / 1024 / 1024} mb")
+        try:
+            requested_uids = self._check_uids(request.uid)
+            metadata = MSGPackSerializer.loads(request.metadata)
+            session_id = metadata["session_id"]
+            self._log_request("rpc_push", requested_uids, context, warning=f"session_id={session_id}")
+
+            self._session_queues[session_id].put(request)
+            return runtime_pb2.ExpertResponse()
+        except Exception as e:
+            logger.exception(e)
     async def _push_outputs(
         self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict
     ) -> None: