|
@@ -590,3 +590,44 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
result.update(block_info)
|
|
|
|
|
|
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
|
|
|
+
|
|
|
@staticmethod
async def _wrap_input_stream(stream):
    """Re-yield requests from ``stream`` up to and including the first one flagged ``SEP``.

    Each request is yielded before its metadata is inspected, so the request
    carrying the ``SEP`` marker is itself delivered to the consumer. Iteration
    also ends quietly if ``stream`` is exhausted before any marker is seen.

    Declared as a static method because it takes no ``self`` yet is invoked as
    ``self._wrap_input_stream(requests)``.

    :param stream: async iterator of request messages exposing a ``metadata`` field
    :returns: async generator over the same request objects
    """
    # ``async for`` rather than ``await anext(...)``: a StopAsyncIteration escaping
    # an async generator body is converted into RuntimeError.
    async for expert_request in stream:
        yield expert_request
        # NOTE(review): request metadata is presumably MSGPack-serialized bytes, like the
        # response metadata this handler emits, so decode it before the key lookup — confirm.
        metadata = MSGPackSerializer.loads(expert_request.metadata) if expert_request.metadata else {}
        if metadata.get("SEP"):
            break
|
|
|
+
|
|
|
async def rpc_forward_backward_stream(
    self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
    """Run a forward pass over a streamed request and stream the serialized outputs back.

    Incoming requests are consumed through ``_wrap_input_stream`` and gathered into
    flat inputs plus metadata. The requested backends are resolved from the uid
    string, ``run_rpc_forward`` is awaited, and every output tensor is split into
    parts of at most ``DEFAULT_MAX_MSG_SIZE`` that are yielded one response at a time.

    Despite the name, only the forward pass is performed in this body.

    :param requests: async iterator of incoming ``ExpertRequest`` messages
    :param context: P2P context of the call, forwarded to input gathering and logging
    :returns: async iterator of ``ExpertResponse`` messages, one tensor part each
    :raises AssertionError: if the ``points`` metadata entry is not an int or float
    """
    async with timeout(self.request_timeout):
        wrapped_requests = self._wrap_input_stream(requests)

        # Parse requests and prepare backends
        uid_str, flat_inputs, metadata = await self._gather_inputs(wrapped_requests, context)
        requested_uids = self._check_uids(uid_str)
        self._log_request("rpc_forward_backward_stream", requested_uids, context)

        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        active_adapter = self._get_active_adapter(metadata)
        points = metadata.get("points", 0)
        args_structure = metadata.get("args_structure")
        assert isinstance(
            points, (float, int)
        ), f"rpc_forward_backward_stream should have number of points as number or None, got {points}"

        hidden_states = await run_rpc_forward(
            *flat_inputs,
            requested_backends=requested_backends,
            prioritizer=self._prioritizer,
            active_adapter=active_adapter,
            points=points,
            args_structure=args_structure,
        )

        # The marker payload is identical for every part, so serialize it once.
        # NOTE(review): every part is tagged "EOS", not only the last one — confirm
        # this matches what the client expects before relying on it as an end marker.
        eos_metadata = MSGPackSerializer.dumps({"EOS": True})
        for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                yield runtime_pb2.ExpertResponse(tensors=[part], metadata=eos_metadata)
|