瀏覽代碼

Forward step_metadata to fix p2p pushing in rpc_inference; found by @miaoqijun in #550

Co-authored-by: Yingtong Dou <ytongdou@gmail.com>
Yingtong Dou 1 年之前
父節點
當前提交
ac48d6670a
共有 2 個文件被更改,包括 4 次插入和 4 次刪除
  1. +2 −2
      src/petals/server/block_functions.py
  2. +2 −2
      src/petals/server/handler.py

+ 2 - 2
src/petals/server/block_functions.py

@@ -153,7 +153,7 @@ async def iterate_rpc_inference(
     points: int,
     quant_type: QuantType,
     args_structure: Any = None,
-) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
     assert len(cache_handles) == len(requested_backends)
 
     prefix_length = 0
@@ -224,7 +224,7 @@ async def iterate_rpc_inference(
             for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]
         can_push = not has_prompts
-        yield output_tensors, can_push
+        yield output_tensors, can_push, step_metadata
 
         # prepare for next step
         prefix_length += length_increment

+ 2 - 2
src/petals/server/handler.py

@@ -171,7 +171,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
                 ) as cache_handles:
                     background_tasks = set()
-                    async for output_tensors, can_push in iterate_rpc_inference(
+                    async for output_tensors, can_push, step_metadata in iterate_rpc_inference(
                         requested_uids=requested_uids,
                         requested_backends=requested_backends,
                         active_adapter=self._get_active_adapter(metadata),
@@ -186,7 +186,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         args_structure=args_structure,
                     ):
                         if can_push:
-                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
+                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
                             background_tasks.add(task)  # Keep reference until it is done to save it from GC
                             task.add_done_callback(background_tasks.discard)
                         yield runtime_pb2.ExpertResponse(tensors=output_tensors)