
implement streaming for moe-expert passes

Pavel Samygin 3 years ago
parent
commit
5c06184a19

+ 1 - 1
benchmarks/benchmark_throughput_p2p.py

@@ -250,7 +250,7 @@ if __name__ == "__main__":
             num_clients=1,
             num_handlers=1,
             num_batches_per_client=args.num_batches_per_client,
-            batch_size=256,
+            batch_size=1024,
         )
     elif args.preset == "nop":
         benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)

+ 25 - 5
hivemind/moe/client/expert.py

@@ -11,8 +11,10 @@ from torch.autograd.function import once_differentiable
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, PeerInfo, StubBase
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils import MSGPackSerializer, asingle, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.utils import MSGPackSerializer, amap_in_executor, as_aiter, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
@@ -133,11 +135,20 @@ class _RemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
 
+        split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
+
         outputs = cls.run_coroutine(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+            stub.rpc_forward(
+                amap_in_executor(
+                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
+                    as_aiter(*split)
+                ),
+            )
         )
 
-        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+        deserialized_outputs = cls.run_coroutine(
+            gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+        )
 
         return tuple(deserialized_outputs)
 
@@ -152,9 +163,18 @@ class _RemoteModuleCall(torch.autograd.Function):
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
         ]
 
+        split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
+
         grad_inputs = cls.run_coroutine(
-            ctx.stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+            ctx.stub.rpc_backward(
+                amap_in_executor(
+                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
+                    as_aiter(*split)
+                ),
+            )
         )
 
-        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
+        deserialized_grad_inputs = cls.run_coroutine(
+            gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+        )
         return (DUMMY, None, None, None, *deserialized_grad_inputs)
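
On the client side, each serialized input tensor is now split into chunks of at most `DEFAULT_MAX_MSG_SIZE // 2` bytes, every chunk travels in its own `ExpertRequest`, and the streamed response is reassembled with `gather_from_grpc`. The sketch below condenses that request-building pattern into a standalone helper; `build_request_stream` is an illustrative name, not a hivemind API, and it reuses only the calls that appear in the diff above.

```python
# Client-side chunking sketch; `build_request_stream` is an illustrative helper,
# not part of hivemind. It reuses only calls introduced in the diff above.
from hivemind.compression import serialize_torch_tensor
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import amap_in_executor, as_aiter
from hivemind.utils.grpc import split_for_streaming


def build_request_stream(uid, tensors, schema):
    # serialize each input with the compression requested by its schema entry
    serialized = [serialize_torch_tensor(t, proto.compression) for t, proto in zip(tensors, schema)]
    # split every serialized tensor into parts that fit under the p2p message limit
    parts = [p for t in serialized for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
    # wrap each part into its own ExpertRequest, producing the async iterator
    # that stub.rpc_forward / stub.rpc_backward now consume
    return amap_in_executor(
        lambda part: runtime_pb2.ExpertRequest(uid=uid, tensors=[part]),
        as_aiter(*parts),
    )
```

The forward and backward passes differ only in which stub method consumes this iterator and which schema drives serialization.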

+ 54 - 18
hivemind/moe/server/connection_handler.py

@@ -1,16 +1,20 @@
 import asyncio
 import multiprocessing as mp
-from typing import AsyncIterator, Dict
+from typing import AsyncIterator, Dict, Iterable, List, Tuple, Union
 
 import torch
 
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.task_pool import TaskPool
 from hivemind.p2p import P2PContext, ServicerBase
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MSGPackSerializer, MPFuture, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
+from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 logger = get_logger(__name__)
 
@@ -55,26 +59,58 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
+    class _RequestUnpacker:
+
+        __slots__ = ("uid",)
+
+        def __init__(self):
+            self.uid = None
+
+        def __call__(self, request: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+            if self.uid is None:
+                self.uid = request.uid
+            else:
+                assert self.uid == request.uid, "Expert uids differ in one request"
+
+            return request.tensors
+
+    async def _gather_inputs(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> Tuple[str, List[torch.Tensor]]:
+        unpacker = self._RequestUnpacker()
+        inputs = await gather_from_grpc(requests, unpacker, deserialize_torch_tensor)
+        return unpacker.uid, inputs
+
+    async def _process_inputs(
+        self, inputs: List[torch.Tensor], pool: TaskPool, schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]]
+    ) -> List[runtime_pb2.Tensor]:
+        return [
+            serialize_torch_tensor(t, p.compression, allow_inplace=True)
+            for t, p in zip(await pool.submit_task(*inputs), nested_flatten(schema))
+        ]
+
     async def rpc_forward(
-        self, request: runtime_pb2.ExpertRequest, context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
-        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-
-        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
-        serialized_response = [
-            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
-            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        uid, inputs = await self._gather_inputs(requests, context)
+        expert = self.experts[uid]
+        output_split = [
+            p for t in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+            for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
         ]
 
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part, ])
 
     async def rpc_backward(
-        self, request: runtime_pb2.ExpertRequest, context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
-        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
-        serialized_response = [
-            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
-            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        uid, inputs_and_grads = await self._gather_inputs(requests, context)
+        expert = self.experts[uid]
+        output_split = [
+            p for t in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+            for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
         ]
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part, ])
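
The server mirrors this: it drains the request stream, rebuilds the input tensors, runs them through the expert's task pool, then re-splits the serialized outputs and yields one `ExpertResponse` per chunk. Below is a condensed, hypothetical sketch of that flow; unlike the real handler it skips the `_RequestUnpacker` uid bookkeeping and takes the resolved `expert` object directly.

```python
# Condensed sketch of the server-side streaming flow; `expert` and `requests`
# are assumed to follow the same interfaces used by ConnectionHandler above.
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import nested_flatten
from hivemind.utils.grpc import gather_from_grpc, split_for_streaming


async def stream_forward(expert, requests):
    # 1) drain the request stream and rebuild the full input tensors
    inputs = await gather_from_grpc(requests, lambda r: r.tensors, deserialize_torch_tensor)
    # 2) run the reassembled batch through the expert's forward pool
    outputs = await expert.forward_pool.submit_task(*inputs)
    # 3) re-serialize the outputs and split them back into streamable parts
    serialized = [
        serialize_torch_tensor(t, proto.compression, allow_inplace=True)
        for t, proto in zip(outputs, nested_flatten(expert.outputs_schema))
    ]
    for part in (p for t in serialized for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)):
        yield runtime_pb2.ExpertResponse(tensors=[part])
```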

+ 26 - 1
hivemind/utils/grpc.py

@@ -6,7 +6,8 @@ from __future__ import annotations
 
 import os
 import threading
-from typing import Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
+import torch
+from typing import Any, AsyncIterator, Callable, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
 
 import grpc
 
@@ -208,3 +209,27 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
         buffer_chunks.append(tensor_part.buffer)
     serialized_tensor.buffer = b"".join(buffer_chunks)
     return serialized_tensor
+
+
+RpcMessage = TypeVar("RpcMessage")
+
+async def gather_from_grpc(
+    stream: AsyncIterator[RpcMessage],
+    key: Callable[[RpcMessage], Iterable[runtime_pb2.Tensor]],
+    deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
+) -> list[torch.Tensor]:
+    tensors = []
+    parts = []
+
+    async for msg in stream:
+        parts_stream = key(msg)
+        for part in parts_stream:
+            if part.dtype and parts:
+                tensors.append(deserializer(combine_from_streaming(parts)))
+                parts = []
+
+            parts.append(part)
+    if parts:
+        tensors.append(deserializer(combine_from_streaming(parts)))
+
+    return tensors
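
`gather_from_grpc` is the inverse of `split_for_streaming`: it walks a stream of messages, starts a new tensor whenever a chunk carries a non-empty `dtype` (only the first chunk of each tensor does), and reassembles the buffers with `combine_from_streaming`. A self-contained round-trip sketch, assuming `serialize_torch_tensor`'s default (no) compression and an in-memory stand-in for the RPC stream:

```python
# Round-trip check for gather_from_grpc using an in-memory async "stream".
import asyncio

import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto import runtime_pb2
from hivemind.utils.grpc import gather_from_grpc, split_for_streaming


async def main():
    tensors = [torch.randn(1000, 1000), torch.randn(5)]
    # split each serialized tensor into 64 KiB parts, a stand-in for the p2p message limit
    parts = [p for t in tensors for p in split_for_streaming(serialize_torch_tensor(t), 2 ** 16)]

    async def fake_stream():  # one chunk per message, mimicking the streaming handlers above
        for part in parts:
            yield runtime_pb2.ExpertResponse(tensors=[part])

    restored = await gather_from_grpc(fake_stream(), key=lambda msg: msg.tensors, deserializer=deserialize_torch_tensor)
    assert all(torch.allclose(a, b) for a, b in zip(tensors, restored))


if __name__ == "__main__":
    asyncio.run(main())
```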