@@ -1,7 +1,7 @@
 import asyncio
 import multiprocessing as mp
 import pickle
-from typing import Dict
+from typing import AsyncIterator, Dict
 
 import torch
 
@@ -56,7 +56,11 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
 
-    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+    async def rpc_forward(
+        self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
+        request = await stream.__anext__()
+
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
         serialized_response = [
@@ -67,8 +71,10 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
     async def rpc_backward(
-        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+        self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> runtime_pb2.ExpertResponse:
+        request = await stream.__anext__()
+
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         serialized_response = [
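
The conversion above keeps unary semantics on top of a streaming handler signature: each handler now receives the whole call as an AsyncIterator of requests and pulls exactly one message from it. A minimal self-contained sketch of that pattern, using hypothetical toy string payloads instead of hivemind's runtime_pb2 protobufs:

import asyncio
from typing import AsyncIterator


async def single_request(payload: str) -> AsyncIterator[str]:
    # A unary call arrives as a one-element stream.
    yield payload


async def handler(stream: AsyncIterator[str]) -> str:
    # Take the single message, mirroring `request = await stream.__anext__()`
    # in rpc_forward/rpc_backward above.
    request = await stream.__anext__()
    return f"handled: {request}"


print(asyncio.run(handler(single_request("ping"))))  # -> handled: ping

On Python 3.10+ the same line can be written `request = await anext(stream)`; calling the dunder directly keeps the code compatible with older interpreters.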