Преглед изворни кода

Show PeerID in TimeoutError, don't show tracebacks

Aleksandr Borzunov пре 2 година
родитељ
комит
ee115dd44e
2 измењених фајлова са 26 додато и 9 уклоњено
  1. 16 7
      src/petals/client/inference_session.py
  2. 10 2
      src/petals/client/remote_forward_backward.py

+ 16 - 7
src/petals/client/inference_session.py

@@ -38,6 +38,7 @@ class _ServerInferenceSession:
     def __init__(
         self,
         uid: ModuleUID,
+        stub: StubBase,
         rpc_info: RPCInfo,
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
@@ -46,7 +47,7 @@ class _ServerInferenceSession:
         max_length: int,
         **metadata,
     ):
-        self.uid, self.rpc_info = uid, rpc_info
+        self.uid, self.stub, self.rpc_info = uid, stub, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@@ -61,11 +62,15 @@ class _ServerInferenceSession:
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
-        outputs_stream = await asyncio.wait_for(
-            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
-            timeout,
-        )
-        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
+        try:
+            outputs_stream = await asyncio.wait_for(
+                stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
+                timeout,
+            )
+        except asyncio.TimeoutError as e:
+            e.args = (f"Timeout on rpc_inference.open(remote_peer=...{stub._peer[-6:]})",)
+            raise
+        return cls(uid, stub, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -125,7 +130,11 @@ class _ServerInferenceSession:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
+        try:
+            return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
+        except asyncio.TimeoutError as e:
+            e.args = (f"Timeout on rpc_inference.step(remote_peer=...{self.stub._peer[-6:]})",)
+            raise
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""

+ 10 - 2
src/petals/client/remote_forward_backward.py

@@ -109,7 +109,11 @@ async def run_remote_forward(
     # call RPC on remote server
     size = sum(t.element_size() * t.nelement() for t in inputs)
     forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary
-    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
+    try:
+        deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
+    except asyncio.TimeoutError as e:
+        e.args = (f"Timeout on rpc_forward(remote_peer=...{stub._peer[-6:]})",)
+        raise
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
@@ -151,5 +155,9 @@ async def run_remote_backward(
 
     size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
     backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary
-    deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
+    try:
+        deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
+    except asyncio.TimeoutError as e:
+        e.args = (f"Timeout on rpc_backward(remote_peer=...{stub._peer[-6:]})",)
+        raise
     return deserialized_grad_inputs