@@ -60,13 +60,14 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
         serialized_response = [
             serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
         ]
 
-        return as_aiter(runtime_pb2.ExpertResponse(tensors=serialized_response))
+        yield runtime_pb2.ExpertResponse(tensors=serialized_response)
 
     async def rpc_backward(
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
@@ -77,4 +78,4 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
         ]
-        return as_aiter(runtime_pb2.ExpertResponse(tensors=serialized_response))
+        yield runtime_pb2.ExpertResponse(tensors=serialized_response)
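
Note on the change above: once an `async def` handler body contains `yield`, calling it already produces an async generator, so a single response no longer needs to be wrapped with `as_aiter` before being handed to the stream consumer. A minimal sketch of that pattern using plain asyncio (the `handle_request` name and string payload are illustrative only, not the hivemind ServicerBase API):

import asyncio
from typing import AsyncIterator


async def handle_request(payload: str) -> AsyncIterator[str]:
    # Because the body contains `yield`, calling this handler returns an async
    # generator; no extra wrapper is needed to turn one response into a stream.
    yield payload.upper()  # stand-in for serializing and emitting one response


async def main() -> None:
    # The consumer iterates the handler the same way whether it yields one
    # message or many.
    async for response in handle_request("hello"):
        print(response)  # -> HELLO


if __name__ == "__main__":
    asyncio.run(main())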