Răsfoiți Sursa

Fix timeout on next token

Aleksandr Borzunov 2 ani în urmă
părinte
comite
2fafbaa119
1 a modificat fișierele cu 4 adăugiri și 3 ștergeri
  1. 4 3
      src/client/inference_session.py

+ 4 - 3
src/client/inference_session.py

@@ -42,6 +42,7 @@ class _ServerInferenceSession:
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
         *,
+        timeout: float,
         max_length: int,
         points: int = 0,
     ):
@@ -49,6 +50,7 @@ class _ServerInferenceSession:
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.timeout = timeout
         self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
         self.stepped = False
         self.closed = False
@@ -63,8 +65,7 @@ class _ServerInferenceSession:
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
             timeout,
         )
-        outputs_stream = aiter_with_timeout(outputs_stream, timeout)
-        return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
+        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -124,7 +125,7 @@ class _ServerInferenceSession:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await anext(self._outputs_stream)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""