
finish renaming experts to module_backends in ConnectionHandler (#487)

Finish renaming experts -> module_backends in ConnectionHandler
justheuristic 3 years ago
parent
commit
f60e34aec1
1 changed file with 9 additions and 8 deletions

hivemind/moe/server/connection_handler.py (+9 −8)

@@ -25,12 +25,12 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
 
     :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port
     :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
-    :param experts: a dict [UID -> ModuleBackend] with all active experts
+    :param module_backends: a dict [UID -> ModuleBackend] with all active experts
     """
 
-    def __init__(self, dht: DHT, experts: Dict[str, ModuleBackend]):
+    def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend]):
         super().__init__()
-        self.dht, self.experts = dht, experts
+        self.dht, self.module_backends = dht, module_backends
         self._p2p: Optional[P2P] = None
 
         self.ready = MPFuture()
@@ -59,7 +59,8 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
-        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
+        module_info = self.module_backends[request.uid].get_info()
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))
 
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -93,7 +94,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        expert = self.experts[request.uid]
+        expert = self.module_backends[request.uid]
         return runtime_pb2.ExpertResponse(
             tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
         )
@@ -102,7 +103,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         uid, inputs = await self._gather_inputs(requests, context)
-        expert = self.experts[uid]
+        expert = self.module_backends[uid]
         output_split = [
             part
             for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
@@ -116,7 +117,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
     ) -> runtime_pb2.ExpertResponse:
         inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        expert = self.experts[request.uid]
+        expert = self.module_backends[request.uid]
         return runtime_pb2.ExpertResponse(
             tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
         )
@@ -125,7 +126,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         uid, inputs_and_grads = await self._gather_inputs(requests, context)
-        expert = self.experts[uid]
+        expert = self.module_backends[uid]
         output_split = [
             part
             for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
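
For reference, a minimal usage sketch of the renamed interface (illustrative only: the empty backend dict and the startup/handshake calls below are assumptions based on this diff, not part of the commit):

from hivemind.dht import DHT
from hivemind.moe.server.connection_handler import ConnectionHandler

# Hypothetical setup: a running DHT node and an (empty, for brevity) backend dict.
dht = DHT(start=True)
module_backends = {}  # Dict[str, ModuleBackend], keyed by expert UID

# After this commit, the dict is passed as `module_backends` (was: `experts`).
handler = ConnectionHandler(dht, module_backends=module_backends)
handler.start()         # ConnectionHandler subclasses ForkProcess, so .start() forks it
handler.ready.result()  # `ready` is an MPFuture (see the diff); block until the handler is up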