
remove useless Unpacker class, some style refactoring

Pavel Samygin · 3 years ago · commit aad0d4db64

+ 7 - 8
hivemind/moe/client/expert.py

@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable

-import hivemind
+from hivemind import moe
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.p2p import P2P, PeerInfo, StubBase
@@ -32,8 +32,8 @@ from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


-def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
-    return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
+def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandlerStub":
+    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)


 @dataclass(frozen=True)
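
The new `-> "ConnectionHandlerStub"` annotation is a string forward reference: it is never evaluated at runtime, which sidesteps the circular import that previously forced the commented-out type hint. A minimal sketch of the same idea (the local import inside the function body is illustrative; only the module path matches the diff above):

    def get_expert_stub(p2p, server_peer_info) -> "ConnectionHandlerStub":
        # quoted annotation: resolved lazily (or never), so no import cycle at module load
        from hivemind.moe.server.connection_handler import ConnectionHandler
        return ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)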
@@ -251,13 +251,12 @@ async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions:
 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""

-    @classmethod
+    @staticmethod
     def forward(
-        cls,
         ctx,
         dummy: torch.Tensor,
         uid: str,
-        stub,  #: ConnectionHandlerStub,
+        stub: "ConnectionHandlerStub",
         info: Dict[str, Any],
         *inputs: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:
@@ -273,9 +272,9 @@ class _RemoteModuleCall(torch.autograd.Function):

         return tuple(deserialized_outputs)

-    @classmethod
+    @staticmethod
     @once_differentiable
-    def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
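
PyTorch documents `forward` and `backward` of a `torch.autograd.Function` as static methods, which is what this hunk switches to. A self-contained toy example of the pattern (`_Scale` and its arguments are illustrative, not from the codebase):

    import torch
    from torch.autograd.function import once_differentiable

    class _Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inputs: torch.Tensor, factor: float) -> torch.Tensor:
            ctx.factor = factor  # per-call state lives on ctx, not on the class
            return inputs * factor

        @staticmethod
        @once_differentiable  # backward runs without grad tracking and cannot be differentiated again
        def backward(ctx, grad_output: torch.Tensor):
            # one gradient per forward argument; the non-tensor `factor` gets None
            return grad_output * ctx.factor, None

    x = torch.randn(3, requires_grad=True)
    _Scale.apply(x, 2.0).sum().backward()  # x.grad is now all 2.0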

+ 1 - 1
hivemind/moe/client/moe.py

@@ -183,7 +183,7 @@ class RemoteMixtureOfExperts(nn.Module):
         if self._expert_info is None:
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
-            dummy_scores_concat: torch.Tensor = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
+            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
             dummy_scores = dummy_scores_concat.cpu().detach().split_with_sizes(self.beam_search.grid_size, dim=-1)
             dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
             self._expert_info = dummy_experts[0].info
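
Context for the hunk above: the gate pushes one throwaway input through `self.proj` purely to discover the output layout, so the explicit annotation was redundant. The trick in isolation (a toy `nn.Linear` stands in for the gating projection):

    import torch
    import torch.nn as nn

    proj = nn.Linear(16, 8)  # stand-in for self.proj
    device = proj.weight.device
    # a dummy batch of size 1 reveals the output shape without real data
    dummy_scores = proj(torch.randn(1, proj.in_features, device=device))
    print(dummy_scores.shape)  # torch.Size([1, 8])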

+ 14 - 19
hivemind/moe/server/connection_handler.py

@@ -19,22 +19,6 @@ from hivemind.utils.tensor_descr import BatchTensorDescriptor
 logger = get_logger(__name__)


-class _RequestUnpacker:
-
-    __slots__ = ("uid",)
-
-    def __init__(self):
-        self.uid: Optional[str] = None
-
-    def __call__(self, request: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
-        if self.uid is None:
-            self.uid = request.uid
-        else:
-            assert self.uid == request.uid, "Expert uids differ in one request"
-
-        return request.tensors
-
-
 class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
@@ -78,9 +62,20 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> Tuple[str, List[torch.Tensor]]:
-        unpacker = _RequestUnpacker()
-        inputs = await gather_from_streaming(requests, unpacker, deserialize_torch_tensor)
-        return unpacker.uid, inputs
+        expert_uid = None
+
+        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+            nonlocal expert_uid
+
+            if expert_uid is None:
+                expert_uid = req.uid
+            elif expert_uid != req.uid:
+                raise ValueError("Expert uids differ in one request")
+
+            return req.tensors
+
+        inputs = await gather_from_streaming(requests, _unpack, deserialize_torch_tensor)
+        return expert_uid, inputs

     async def _process_inputs(
         self,
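
The refactoring above trades a one-field helper class for a closure: `nonlocal` lets the nested `_unpack` mutate `expert_uid` in the enclosing scope. The pattern in isolation (names are illustrative, not from the codebase):

    from typing import Iterable, Optional

    def make_unpacker():
        seen_uid: Optional[str] = None  # state formerly held in _RequestUnpacker.uid

        def unpack(request) -> Iterable:
            nonlocal seen_uid
            if seen_uid is None:
                seen_uid = request.uid  # the first request pins the uid
            elif seen_uid != request.uid:
                raise ValueError("Expert uids differ in one request")
            return request.tensors

        return unpack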

+ 1 - 1
hivemind/moe/server/server.py

@@ -41,7 +41,7 @@ class Server(threading.Thread):
      - processes incoming forward/backward requests via Runtime (created by the server)
      - publishes updates to expert status every :update_period: seconds

-    :type dht: DHT.
+    :type dht: an instance of hivemind.DHT.
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all experts hosted by this server.
     :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
     :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1

+ 3 - 3
tests/test_p2p_daemon_bindings.py

@@ -560,7 +560,7 @@ async def test_client_stream_handler_success(p2pcs):

     writer.close()

-    # test case: registering twice can override the previous registration
+    # test case: registering twice cannot override the previous registration without the balanced flag
     event_third = asyncio.Event()

     async def handler_third(stream_info, reader, writer):
@@ -570,8 +570,8 @@ async def test_client_stream_handler_success(p2pcs):
     with pytest.raises(ControlFailure):
         await p2pcs[1].stream_handler(another_proto, handler_third)

-    # add in balanced mode, know handler should be placed in round robin queue
-    # also it should be next to be called
+    # add in balanced mode: handler should be placed in round robin queue
+    # and become the next to be called
     await p2pcs[1].stream_handler(another_proto, handler_third, True)
     assert another_proto in p2pcs[1].control.handlers
     # ensure the handler is overridden
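
For reference, the registration semantics this test asserts, as a hedged usage sketch (`client`, `proto`, and the handlers are placeholders; the import path is assumed):

    from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure  # path assumed

    async def register_handlers(client, proto, handler_a, handler_b):
        await client.stream_handler(proto, handler_a)  # first registration binds proto
        try:
            await client.stream_handler(proto, handler_b)  # plain re-registration is rejected
        except ControlFailure:
            pass
        # balanced flag: handler_b joins a round-robin queue and is called next
        await client.stream_handler(proto, handler_b, True)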
     # ensure the handler is override