Denis Mazur 1 年之前
父節點
當前提交
f8b922ca34
共有 1 個文件被更改,包括 41 次插入和 0 次删除
  1. 41 0
      src/petals/server/handler.py

+ 41 - 0
src/petals/server/handler.py

@@ -590,3 +590,44 @@ class TransformerConnectionHandler(ConnectionHandler):
             result.update(block_info)
 
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
+
+    async def _wrap_input_stream(stream):
+        while True:
+            expert_request = await anext(stream)
+            yield expert_request
+            print(expert_request.metadata)
+            if expert_request.metadata.get("SEP"):
+                break
+
+    async def rpc_forward_backward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        async with timeout(self.request_timeout):
+            wrapped_requests = self._wrap_input_stream(requests)
+
+            # Parse requests and prepare backends
+            uid_str, flat_inputs, metadata = await self._gather_inputs(wrapped_requests, context)
+            requested_uids = self._check_uids(uid_str)
+            self._log_request("rpc_forward_stream", requested_uids, context)
+
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = self._get_active_adapter(metadata)
+            points = metadata.get("points", 0)
+            args_structure = metadata.get("args_structure")
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_forward_stream should have number of points as number or None, got {points}"
+
+            hidden_states = await run_rpc_forward(
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
+                args_structure=args_structure,
+            )
+
+            for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    print("EOS")
+                    yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"EOS": True}))