
fix review issues

Pavel Samygin 3 years ago
parent
commit
2fcf7368ff

+ 4 - 3
benchmarks/benchmark_throughput.py

@@ -8,7 +8,8 @@ import torch
 
 from hivemind.dht import DHT
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
-from hivemind.moe.server import ExpertBackend, Server, layers
+from hivemind.moe.server import ExpertBackend, Server
+from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p import P2P, PeerInfo
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -82,7 +83,7 @@ def benchmark_throughput(
         or not torch.cuda.is_initialized()
         or torch.device(device) == torch.device("cpu")
     )
-    assert expert_cls in layers.name_to_block
+    assert expert_cls in name_to_block
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()
@@ -119,7 +120,7 @@ def benchmark_throughput(
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
-            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
+            expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
             experts[f"expert.{i}"] = ExpertBackend(
                 name=f"expert.{i}",
                 expert=expert,

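The benchmark now imports name_to_block directly from hivemind.moe.server.layers instead of reaching through the layers module. A minimal sketch of how that registry is used to build a scriptable expert, assuming hivemind is installed; the "ffn" key and hid_dim=1024 below are illustrative choices, not values taken from this diff:

    import torch

    from hivemind.moe.server.layers import name_to_block

    # name_to_block maps a layer name to a constructor that takes the hidden dimension;
    # "ffn" and hid_dim=1024 are assumptions for illustration, any registered key works
    hid_dim = 1024
    expert = torch.jit.script(name_to_block["ffn"](hid_dim))

    dummy_batch = torch.randn(4, hid_dim)
    print(expert(dummy_batch).shape)  # expected: torch.Size([4, 1024]) for a feedforward block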
+ 5 - 5
hivemind/moe/client/expert.py

@@ -24,7 +24,7 @@ from hivemind.utils import (
     nested_pack,
     switch_to_uvloop,
 )
-from hivemind.utils.grpc import gather_from_rpc, split_for_streaming
+from hivemind.utils.grpc import gather_from_streaming, split_for_streaming
 from hivemind.utils.mpfuture import MPFuture
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
@@ -180,12 +180,12 @@ async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Te
 
     grad_inputs = await stub.rpc_backward_stream(
         amap_in_executor(
-            lambda t: runtime_pb2.ExpertRequest(uid=uid, tensors=[t]),
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
             as_aiter(*split),
         ),
     )
 
-    return await gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+    return await gather_from_streaming(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
 
 
 async def _backward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
@@ -216,12 +216,12 @@ async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Ten
 
     outputs = await stub.rpc_forward_stream(
         amap_in_executor(
-            lambda t: runtime_pb2.ExpertRequest(uid=uid, tensors=[t]),
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
             as_aiter(*split),
         ),
     )
 
-    return await gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+    return await gather_from_streaming(outputs, lambda r: r.tensors, deserialize_torch_tensor)
 
 
 async def _forward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:

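Both streaming helpers in this file follow the same round trip: split the serialized tensors into chunks, wrap each chunk in an ExpertRequest, stream the requests, and collect the responses with the renamed gather_from_streaming. Below is a self-contained toy sketch of that pattern using plain asyncio; the Part/Request classes and the gather_parts helper are stand-ins invented for illustration (the real helper also recombines streamed chunks into whole tensors before deserializing):

    import asyncio
    from dataclasses import dataclass
    from typing import AsyncIterator, Callable, Iterable, List, TypeVar

    Msg = TypeVar("Msg")

    @dataclass
    class Part:  # stand-in for a runtime_pb2.Tensor chunk
        payload: bytes

    @dataclass
    class Request:  # stand-in for runtime_pb2.ExpertRequest
        uid: str
        tensors: List[Part]

    async def as_aiter(*items):
        # yields items one by one, mirroring the role of hivemind's as_aiter
        for item in items:
            yield item

    async def gather_parts(
        stream: AsyncIterator[Msg],
        key: Callable[[Msg], Iterable[Part]],
        deserializer: Callable[[Part], bytes],
    ) -> List[bytes]:
        # mirrors the gather_from_streaming contract: pull messages, extract parts, deserialize
        gathered = []
        async for message in stream:
            gathered.extend(deserializer(part) for part in key(message))
        return gathered

    async def main():
        split = [Part(b"chunk-0"), Part(b"chunk-1")]  # pretend split_for_streaming produced these
        requests = as_aiter(*(Request(uid="expert.0", tensors=[part]) for part in split))
        # a real stub would stream back ExpertResponse messages; here we just echo the requests
        print(await gather_parts(requests, key=lambda r: r.tensors, deserializer=lambda p: p.payload))

    asyncio.run(main())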
+ 28 - 27
hivemind/moe/server/connection_handler.py

@@ -1,6 +1,6 @@
 import asyncio
 import multiprocessing as mp
-from typing import AsyncIterator, Dict, Iterable, List, Tuple, Union
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 
@@ -13,12 +13,28 @@ from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.grpc import gather_from_rpc, split_for_streaming
+from hivemind.utils.grpc import gather_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 logger = get_logger(__name__)
 
 
+class _RequestUnpacker:
+
+    __slots__ = ("uid",)
+
+    def __init__(self):
+        self.uid: Optional[str] = None
+
+    def __call__(self, request: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+        if self.uid is None:
+            self.uid = request.uid
+        else:
+            assert self.uid == request.uid, "Expert uids differ in one request"
+
+        return request.tensors
+
+
 class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
@@ -59,26 +75,11 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
-    class _RequestUnpacker:
-
-        __slots__ = ("uid",)
-
-        def __init__(self):
-            self.uid = None
-
-        def __call__(self, request: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
-            if self.uid is None:
-                self.uid = request.uid
-            else:
-                assert self.uid == request.uid, "Expert uids differ in one request"
-
-            return request.tensors
-
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> Tuple[str, List[torch.Tensor]]:
-        unpacker = self._RequestUnpacker()
-        inputs = await gather_from_rpc(requests, unpacker, deserialize_torch_tensor)
+        unpacker = _RequestUnpacker()
+        inputs = await gather_from_streaming(requests, unpacker, deserialize_torch_tensor)
         return unpacker.uid, inputs
 
     async def _process_inputs(
@@ -88,8 +89,8 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]],
     ) -> List[runtime_pb2.Tensor]:
         return [
-            serialize_torch_tensor(t, p.compression, allow_inplace=True)
-            for t, p in zip(await pool.submit_task(*inputs), nested_flatten(schema))
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(await pool.submit_task(*inputs), nested_flatten(schema))
         ]
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
@@ -105,9 +106,9 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         uid, inputs = await self._gather_inputs(requests, context)
         expert = self.experts[uid]
         output_split = [
-            p
-            for t in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
-            for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
+            part
+            for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE // 2)
         ]
 
         async for part in as_aiter(*output_split):
@@ -128,9 +129,9 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         uid, inputs_and_grads = await self._gather_inputs(requests, context)
         expert = self.experts[uid]
         output_split = [
-            p
-            for t in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
-            for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
+            part
+            for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE // 2)
         ]
 
         async for part in as_aiter(*output_split):

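_RequestUnpacker moves from a nested class to module level, but its contract is unchanged: it pins the uid of the first request in the stream, asserts that every later request carries the same uid, and returns each request's tensors so gather_from_streaming can collect them. A minimal sketch of that contract, using a hypothetical stand-in for runtime_pb2.ExpertRequest:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class FakeRequest:  # stand-in for runtime_pb2.ExpertRequest, for illustration only
        uid: str
        tensors: List[str] = field(default_factory=list)

    class RequestUnpacker:  # mirrors the behavior of _RequestUnpacker
        __slots__ = ("uid",)

        def __init__(self):
            self.uid: Optional[str] = None

        def __call__(self, request) -> List[str]:
            if self.uid is None:
                self.uid = request.uid  # first request pins the expert uid for the whole stream
            else:
                assert self.uid == request.uid, "Expert uids differ in one request"
            return request.tensors

    unpacker = RequestUnpacker()
    print(unpacker(FakeRequest("expert.0", ["t0"])))  # ['t0'], uid pinned to "expert.0"
    print(unpacker(FakeRequest("expert.0", ["t1"])))  # ['t1']
    # unpacker(FakeRequest("expert.1"))               # would raise AssertionError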
+ 2 - 0
hivemind/p2p/p2p_daemon.py

@@ -455,6 +455,8 @@ class P2P:
                              (not just ``TInputProtobuf``) as input.
         :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
                               (not ``Awaitable[TOutputProtobuf]``).
+        :param balanced: If True, handler will be balanced on p2pd side between all handlers in python.
+                         Default: False
         """
 
         if not stream_input and not stream_output:

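A hedged sketch of how the new flag could be passed when registering a handler on the high-level P2P wrapper. The enclosing method's name is not visible in this hunk; add_protobuf_handler, the handler signature, and the protocol name below are assumptions made for illustration:

    from hivemind.p2p import P2P, P2PContext
    from hivemind.proto import runtime_pb2

    async def rpc_echo(request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # assumed unary handler shape (stream_input/stream_output left at their defaults)
        return runtime_pb2.ExpertResponse(tensors=request.tensors)

    async def register_balanced_handler(p2p: P2P) -> None:
        # balanced=True asks p2pd to spread incoming calls for this protocol across
        # every Python handler that registered it, instead of always using the last one
        await p2p.add_protobuf_handler(  # method name is an assumption, not shown in this hunk
            "ExampleHandler.rpc_echo",   # illustrative protocol name
            rpc_echo,
            runtime_pb2.ExpertRequest,
            balanced=True,
        )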
+ 1 - 0
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -110,6 +110,7 @@ class Client:
         Register a stream handler
         :param proto: protocols that handler serves
         :param handler_cb: handler callback
+        :param balanced: flag if stream handler should be balanced on p2pd side. Default: False.
         :return:
         """
         await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced)

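The same flag is also documented on the low-level daemon client. A sketch of calling Client.stream_handler with it, based on the signature visible in this hunk; the callback shape and protocol name are assumptions for illustration:

    from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client

    async def echo_handler(stream_info, reader, writer) -> None:
        # assumed (stream_info, reader, writer) callback shape; not shown in this hunk
        writer.write(await reader.read(1024))
        await writer.drain()

    async def register(client: Client) -> None:
        # balanced=True lets p2pd load-balance streams for this protocol across all registered handlers
        await client.stream_handler("/example/echo/1.0.0", echo_handler, balanced=True)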
+ 4 - 4
hivemind/utils/grpc.py

@@ -224,12 +224,12 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
     return serialized_tensor
 
 
-RpcMessage = TypeVar("RpcMessage")
+StreamMessage = TypeVar("StreamMessage")
 
 
-async def gather_from_rpc(
-    stream: AsyncIterator[RpcMessage],
-    key: Callable[[RpcMessage], Iterable[runtime_pb2.Tensor]],
+async def gather_from_streaming(
+    stream: AsyncIterator[StreamMessage],
+    key: Callable[[StreamMessage], Iterable[runtime_pb2.Tensor]],
     deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
 ) -> List[torch.Tensor]:
     tensors = []
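The hunk above shows only the renamed signature and the first line of the body. For orientation, here is a simplified sketch of what such a gather loop can look like; it is not the library's actual body, and the use of a non-empty dtype field to detect the start of a new tensor is an assumption about how split_for_streaming marks chunk boundaries:

    from typing import AsyncIterator, Callable, Iterable, List, TypeVar

    import torch

    from hivemind.proto import runtime_pb2
    from hivemind.utils.grpc import combine_from_streaming

    StreamMessage = TypeVar("StreamMessage")

    async def gather_from_streaming_sketch(
        stream: AsyncIterator[StreamMessage],
        key: Callable[[StreamMessage], Iterable[runtime_pb2.Tensor]],
        deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
    ) -> List[torch.Tensor]:
        tensors: List[torch.Tensor] = []
        parts: List[runtime_pb2.Tensor] = []
        async for message in stream:
            for part in key(message):
                if parts and part.dtype:  # assumption: only the first chunk of a tensor carries metadata
                    tensors.append(deserializer(combine_from_streaming(parts)))
                    parts = []
                parts.append(part)
        if parts:
            tensors.append(deserializer(combine_from_streaming(parts)))
        return tensors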