|
@@ -42,6 +42,7 @@ class _ServerInferenceSession:
|
|
|
inputs_queue: asyncio.Queue,
|
|
|
outputs_aiter: AsyncIterator,
|
|
|
*,
|
|
|
+ timeout: float,
|
|
|
max_length: int,
|
|
|
points: int = 0,
|
|
|
):
|
|
@@ -49,6 +50,7 @@ class _ServerInferenceSession:
|
|
|
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
|
|
|
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
|
|
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
|
|
|
+ self.timeout = timeout
|
|
|
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
|
|
|
self.stepped = False
|
|
|
self.closed = False
|
|
@@ -63,8 +65,7 @@ class _ServerInferenceSession:
|
|
|
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
|
|
|
timeout,
|
|
|
)
|
|
|
- outputs_stream = aiter_with_timeout(outputs_stream, timeout)
|
|
|
- return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
|
|
|
+ return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
|
|
|
|
|
|
@staticmethod
|
|
|
async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
|
|
@@ -124,7 +125,7 @@ class _ServerInferenceSession:
|
|
|
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
await self._inputs_queue.put(inputs_serialized)
|
|
|
self.stepped = True
|
|
|
- return await anext(self._outputs_stream)
|
|
|
+ return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
|
|
|
|
|
|
def close(self):
|
|
|
"""Finish a given inference session, close the underlying connection"""
|