|
@@ -71,6 +71,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
print("CLOSED RPC_INFERENCE")
|
|
|
|
|
|
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
|
|
+ return await super().rpc_forward(request, context)
|
|
|
# Parse request and prepare backends
|
|
|
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
|
|
requested_uids = self._check_header(request)
|
|
@@ -96,6 +97,8 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
async def rpc_forward_stream(
|
|
|
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
|
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
|
|
+ async for response in super().rpc_forward_stream(requests, context):
|
|
|
+ yield response
|
|
|
# Parse requests and prepare backends
|
|
|
uids_header, hidden_states = await self._gather_inputs(requests, context)
|
|
|
requested_uids = self._check_header_str(uids_header)
|
|
@@ -124,6 +127,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
yield runtime_pb2.ExpertResponse(tensors=[part])
|
|
|
|
|
|
async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
|
|
+ return await super().rpc_backward(request, context)
|
|
|
# Parse requests and prepare backends
|
|
|
inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
|
|
requested_uids = self._check_header(request)
|
|
@@ -157,6 +161,8 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
async def rpc_backward_stream(
|
|
|
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
|
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
|
|
|
+ async for response in super().rpc_backward_stream(requests, context):
|
|
|
+ yield response
|
|
|
uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
|
|
|
inputs, grads = inputs_and_grads
|
|
|
requested_uids = self._check_header_str(uids_header)
|