|
@@ -171,7 +171,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
|
|
|
) as cache_handles:
|
|
|
background_tasks = set()
|
|
|
- async for output_tensors, can_push in iterate_rpc_inference(
|
|
|
+ async for output_tensors, can_push, step_metadata in iterate_rpc_inference(
|
|
|
requested_uids=requested_uids,
|
|
|
requested_backends=requested_backends,
|
|
|
active_adapter=self._get_active_adapter(metadata),
|
|
@@ -186,7 +186,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
args_structure=args_structure,
|
|
|
):
|
|
|
if can_push:
|
|
|
- task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
|
|
|
+ task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
|
|
|
background_tasks.add(task) # Keep reference until it is done to save it from GC
|
|
|
task.add_done_callback(background_tasks.discard)
|
|
|
yield runtime_pb2.ExpertResponse(tensors=output_tensors)
|