@@ -25,12 +25,12 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
 
     :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port
     :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
-    :param experts: a dict [UID -> ModuleBackend] with all active experts
+    :param module_backends: a dict [UID -> ModuleBackend] with all active experts
     """
 
-    def __init__(self, dht: DHT, experts: Dict[str, ModuleBackend]):
+    def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend]):
         super().__init__()
-        self.dht, self.experts = dht, experts
+        self.dht, self.module_backends = dht, module_backends
         self._p2p: Optional[P2P] = None
 
         self.ready = MPFuture()
@@ -59,7 +59,8 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
                 logger.debug("Caught KeyboardInterrupt, shutting down")
 
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
-        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
+        module_info = self.module_backends[request.uid].get_info()
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))
 
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -93,7 +94,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        expert = self.experts[request.uid]
+        expert = self.module_backends[request.uid]
         return runtime_pb2.ExpertResponse(
             tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
         )
@@ -102,7 +103,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         uid, inputs = await self._gather_inputs(requests, context)
-        expert = self.experts[uid]
+        expert = self.module_backends[uid]
         output_split = [
             part
             for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
@@ -116,7 +117,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
     ) -> runtime_pb2.ExpertResponse:
         inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        expert = self.experts[request.uid]
+        expert = self.module_backends[request.uid]
         return runtime_pb2.ExpertResponse(
             tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
         )
@@ -125,7 +126,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         uid, inputs_and_grads = await self._gather_inputs(requests, context)
-        expert = self.experts[uid]
+        expert = self.module_backends[uid]
         output_split = [
             part
             for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
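
For reference, a minimal usage sketch of the renamed argument; this is not part of the patch. The import paths assume hivemind's post-rename module layout, and the empty `module_backends` dict stands in for a real UID -> ModuleBackend mapping built by your server setup.

from typing import Dict

from hivemind.dht import DHT
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.moe.server.module_backend import ModuleBackend

dht = DHT(start=True)  # a running DHT node, as required by the docstring above
module_backends: Dict[str, ModuleBackend] = {}  # fill with UID -> ModuleBackend entries

# After this patch, the dict is passed as `module_backends` (previously `experts`)
handler = ConnectionHandler(dht, module_backends=module_backends)
handler.start()         # ForkProcess: the handler serves requests in its own process
handler.ready.result()  # the MPFuture from __init__ resolves once the handler is ready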