@@ -10,7 +10,7 @@ from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import MPFuture, asingle, get_logger, nested_flatten
+from hivemind.utils import MPFuture, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop

 logger = get_logger(__name__)
@@ -57,10 +57,8 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))

     async def rpc_forward(
-        self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
-        request = await asingle(stream)
-
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
         serialized_response = [
@@ -68,17 +66,15 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
         ]
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+        return as_aiter(runtime_pb2.ExpertResponse(tensors=serialized_response))

     async def rpc_backward(
-        self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
-        request = await asingle(stream)
-
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         serialized_response = [
             serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
         ]
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+        return as_aiter(runtime_pb2.ExpertResponse(tensors=serialized_response))
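Note: this diff flips the handler shape in both directions at once. On the request side, `rpc_forward`/`rpc_backward` now receive a single `ExpertRequest`, so `asingle` (which collapsed an incoming request stream into one message) is no longer needed; on the response side, they return an `AsyncIterator`, and `as_aiter` wraps the single already-computed `ExpertResponse` into that stream. As a sketch of the idea only (the real helper lives in hivemind's utils; this standalone toy version and the `demo` function are illustrative, not the library's exact code):

```python
import asyncio
from typing import AsyncIterator, TypeVar

T = TypeVar("T")


async def as_aiter(*args: T) -> AsyncIterator[T]:
    # Yield each pre-computed value in turn, turning a unary result
    # into a single-use async stream.
    for arg in args:
        yield arg


async def demo() -> None:
    # A handler can now `return as_aiter(response)`: the caller
    # consumes a one-element stream instead of a bare message.
    async for item in as_aiter("response"):
        print(item)


asyncio.run(demo())
```

Because the response is fully materialized before `as_aiter` is called, this keeps the servicer's stream-response interface without adding any buffering or chunking logic to the handlers themselves.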