فهرست منبع

refactor moe package

Denis Mazur 4 سال پیش
والد
کامیت
bddf730480
4فایلهای تغییر یافته به همراه315 افزوده شده و 163 حذف شده
  1. 236 0
      benchmarks/benchmark_throughput-p2p.py
  2. 1 1
      hivemind/moe/__init__.py
  3. 78 22
      hivemind/moe/client/expert.py
  4. 0 140
      hivemind/moe/expert.py

+ 236 - 0
benchmarks/benchmark_throughput-p2p.py

@@ -0,0 +1,236 @@
+import argparse
+import multiprocessing as mp
+import random
+import sys
+import time
+
+import torch
+
+import hivemind
+from hivemind import get_free_port
+from hivemind.moe.server import layers
+from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def print_device_info(device=None):
+    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
+    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
+    logger.info(f"Using device: {device}")
+
+    # Additional Info when using cuda
+    if device.type == "cuda":
+        logger.info(torch.cuda.get_device_name(0))
+        logger.info(f"Memory Usage:")
+        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
+        logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
+
+
+def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+    torch.set_num_threads(1)
+    can_start.wait()
+    experts = [
+        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
+    ]
+
+    try:
+        dummy_batch = torch.randn(batch_size, hid_dim)
+        for batch_i in range(num_batches):
+            expert = random.choice(experts)
+            out = expert(dummy_batch)
+            if backprop:
+                out.sum().backward()
+    except BaseException as e:
+        benchmarking_failed.set()
+        raise e
+
+
+def benchmark_throughput(
+    num_experts=16,
+    num_handlers=None,
+    num_clients=128,
+    num_batches_per_client=16,
+    expert_cls="ffn",
+    hid_dim=1024,
+    batch_size=2048,
+    max_batch_size=None,
+    backprop=True,
+    device=None,
+    port=None,
+):
+    assert (
+        not hasattr(torch.cuda, "is_initialized")
+        or not torch.cuda.is_initialized()
+        or torch.device(device) == torch.device("cpu")
+    )
+    assert expert_cls in layers.name_to_block
+    port = port or get_free_port()
+    max_batch_size = max_batch_size or batch_size * 4
+    num_handlers = max(1, num_handlers or num_clients // 2)
+    benchmarking_failed = mp.Event()
+    can_start = mp.Event()
+    timestamps = dict(started=time.perf_counter())
+
+    try:
+        # start clients and await server
+        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        clients = [
+            mp.Process(
+                target=client_process,
+                name=f"client_process-{i}",
+                args=(
+                    can_start,
+                    benchmarking_failed,
+                    port,
+                    num_experts,
+                    batch_size,
+                    hid_dim,
+                    num_batches_per_client,
+                    backprop,
+                ),
+            )
+            for i in range(num_clients)
+        ]
+
+        for client in clients:
+            client.daemon = True
+            client.start()
+
+        timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
+
+        # start server
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        experts = {}
+        for i in range(num_experts):
+            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
+            experts[f"expert{i}"] = hivemind.ExpertBackend(
+                name=f"expert{i}",
+                expert=expert,
+                optimizer=torch.optim.Adam(expert.parameters()),
+                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
+                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
+                max_batch_size=max_batch_size,
+            )
+        timestamps["created_experts"] = time.perf_counter()
+        server = hivemind.moe.Server(
+            None,
+            experts,
+            listen_on=f"{hivemind.LOCALHOST}:{port}",
+            num_connection_handlers=num_handlers,
+            device=device,
+        )
+        server.start()
+        server.ready.wait()
+        timestamps["server_ready"] = time.perf_counter()
+        can_start.set()
+
+        for client in clients:
+            client.join()
+        timestamps["clients_finished"] = time.perf_counter()
+    except BaseException as e:
+        benchmarking_failed.set()
+        raise e
+    finally:
+        for client in clients:
+            if client.is_alive():
+                client.terminate()
+        server.shutdown()
+        timestamps["server_shutdown_finished"] = time.perf_counter()
+        server.join()
+
+    sys.stdout.flush()
+    sys.stderr.flush()
+    time_between = (
+        lambda key1, key2: abs(timestamps[key2] - timestamps[key1])
+        if (key1 in timestamps and key2 in timestamps)
+        else float("nan")
+    )
+    total_examples = batch_size * num_clients * num_batches_per_client
+
+    logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
+    logger.info(
+        f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
+        f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}"
+    )
+    logger.info(
+        f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
+        f"batch_size={batch_size}, backprop={backprop}"
+    )
+
+    logger.info("Results: ")
+    logger.info(
+        f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
+        f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
+        f"{time_between('created_experts', 'server_ready') :.3f} s. networking)"
+    )
+    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
+    logger.info(
+        f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
+        f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s."
+    )
+    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
+    if benchmarking_failed.is_set():
+        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
+    print_device_info(device)
+    sys.stdout.flush()
+    sys.stderr.flush()
+
+    assert not benchmarking_failed.is_set()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--preset", type=str, default="default", required=False)
+    parser.add_argument("--num_batches_per_client", type=int, default=16, required=False)
+    args = parser.parse_args()
+
+    if args.preset in ("default", "ffn_forward_backward"):
+        benchmark_throughput()
+    elif args.preset == "ffn_forward":
+        benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
+    elif args.preset == "ffn_small_batch":
+        benchmark_throughput(
+            backprop=False,
+            num_experts=4,
+            batch_size=32,
+            max_batch_size=8192,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "ffn_small_batch_512clients":
+        benchmark_throughput(
+            backprop=True,
+            num_experts=1,
+            batch_size=1,
+            max_batch_size=8192,
+            num_clients=512,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "ffn_small_batch_512clients_32handlers":
+        benchmark_throughput(
+            backprop=True,
+            num_experts=1,
+            batch_size=1,
+            max_batch_size=8192,
+            num_handlers=32,
+            num_clients=512,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "ffn_massive":
+        increase_file_limit()
+        benchmark_throughput(
+            backprop=False,
+            num_clients=512,
+            batch_size=512,
+            max_batch_size=8192,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "minimalistic":
+        benchmark_throughput(
+            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
+        )
+    elif args.preset == "nop":
+        benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
+    else:
+        raise ValueError(f"No such benchmark preset: {args.preset}")

+ 1 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,2 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class
+from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class, ConnectionHandler

+ 78 - 22
hivemind/moe/client/expert.py

@@ -1,22 +1,33 @@
+import asyncio
+from concurrent.futures import Future
+import multiprocessing as mp
 import pickle
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Awaitable
+from threading import Thread
+from queue import Queue
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
+
+#from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils.grpc import ChannelCache
+from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.p2p import P2P, PeerInfo, StubBase
+from hivemind.proto import runtime_pb2
+
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 
-def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
-    """Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache"""
-    channel_options = (("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)) + extra_options
-    return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
+import hivemind
+
+#ConnectionHandlerStub = hivemind.moe.server.connection_handler.ConnectionHandler._stub_type
+
+
+def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo): # -> ConnectionHandlerStub:
+    return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
 
 
 class RemoteExpert(nn.Module):
@@ -26,17 +37,23 @@ class RemoteExpert(nn.Module):
     Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
     :param uid: unique expert identifier
-    :param endpoint: network endpoint of a server that services that expert, e.g. "201.123.321.99:1337" or "[::]:8080"
     """
 
-    def __init__(self, uid, endpoint: Endpoint):
+    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None):
         super().__init__()
-        self.uid, self.endpoint = uid, endpoint
+        self.uid, self.server_peer_info = uid, server_peer_info
         self._info = None
 
+        if p2p is None:
+            self.p2p = _RemoteModuleCall.run_coroutine(P2P.create())
+        else:
+            self.p2p = p2p
+
+        _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
+
     @property
-    def stub(self):
-        return _get_expert_stub(self.endpoint)
+    def stub(self) -> StubBase:
+        return _get_expert_stub(self.p2p, self.server_peer_info)
 
     def forward(self, *args, **kwargs):
         """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
@@ -49,31 +66,66 @@ class RemoteExpert(nn.Module):
 
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
-
         flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
+
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
 
     @property
     def info(self):
         if self._info is None:
-            outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
+            outputs = _RemoteModuleCall.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
             self._info = pickle.loads(outputs.serialized_info)
         return self._info
 
     def extra_repr(self):
-        return f"uid={self.uid}, endpoint={self.endpoint}"
+        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
 
 
 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
 
-    @staticmethod
+    _task_queue: Queue = Queue()
+    _event_thread: Optional[Thread] = None
+
+    @classmethod
+    def _run(cls):
+        loop = switch_to_uvloop()
+
+        async def receive_tasks():
+            while True:
+                cor, future = cls._task_queue.get()
+                try:
+                    result = await cor
+                except Exception as e:
+                    future.set_exception(e)
+                    continue
+
+                future.set_result(result)
+
+        loop.run_until_complete(receive_tasks())
+
+    @classmethod
+    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
+        if cls._event_thread is None:
+            cls._event_thread = Thread(target=cls._run)
+            cls._event_thread.start()
+
+        future = Future()
+        cls._task_queue.put((coro, future))
+
+        if return_future:
+            return future
+
+        return future.result()
+
+    @classmethod
     def forward(
+        cls,
         ctx,
         dummy: torch.Tensor,
         uid: str,
-        stub: runtime_grpc.ConnectionHandlerStub,
+        stub,#: ConnectionHandlerStub,
         info: Dict[str, Any],
         *inputs: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:
@@ -88,15 +140,17 @@ class _RemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
 
-        outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        outputs = cls.run_coroutine(
+            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+        )
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
 
         return tuple(deserialized_outputs)
 
-    @staticmethod
+    @classmethod
     @once_differentiable
-    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+    def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
@@ -105,7 +159,9 @@ class _RemoteModuleCall(torch.autograd.Function):
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
         ]
 
-        grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        grad_inputs = cls.run_coroutine(
+            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+        )
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, *deserialized_grad_inputs)

+ 0 - 140
hivemind/moe/expert.py

@@ -1,140 +0,0 @@
-import asyncio
-import multiprocessing as mp
-import pickle
-from typing import Any, Dict, Optional, Tuple
-from threading import Thread
-
-import torch
-import torch.nn as nn
-from torch.autograd.function import once_differentiable
-
-
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
-from hivemind.p2p import P2P, PeerInfo, StubBase
-from hivemind.proto import runtime_pb2
-from hivemind.moe.server import ConnectionHandler
-
-
-DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
-
-
-ConnectionHandlerStub = ConnectionHandler._stub_type
-
-
-def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> ConnectionHandlerStub:
-    return ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
-
-
-class RemoteExpert(nn.Module):
-    """
-    A simple module that runs forward/backward of an expert hosted on a remote machine.
-    Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
-    Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
-    Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
-    :param uid: unique expert identifier
-    """
-
-    def __init__(self, uid, p2p: P2P, server_peer_info: PeerInfo):
-        super().__init__()
-        self.uid, self.p2p, self.server_peer_info = uid, p2p, server_peer_info
-        self._info = None
-
-        def _run():
-            self.loop = switch_to_uvloop()
-            self.loop.run_forever()
-
-        Thread(target=_run).start()
-
-    @property
-    def stub(self) -> StubBase:
-        return _get_expert_stub(self.p2p, self.server_peer_info)
-
-    def forward(self, *args, **kwargs):
-        """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
-        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
-        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
-
-        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
-
-        forward_inputs = (args, kwargs)
-
-        if not nested_compare(forward_inputs, self.info["forward_schema"]):
-            raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
-        flat_outputs = _RemoteModuleCall.apply(
-            DUMMY,
-            self.uid,
-            self.stub,
-            self.loop,
-            self.info,
-            *nested_flatten(forward_inputs),
-        )
-
-        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
-        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
-
-    @property
-    def info(self):
-        if self._info is None:
-            outputs = asyncio.run_coroutine_threadsafe(
-                self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)),
-                self.loop
-            ).result()
-            self._info = pickle.loads(outputs.serialized_info)
-        return self._info
-
-    def extra_repr(self):
-        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
-
-
-class _RemoteModuleCall(torch.autograd.Function):
-    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
-
-    @staticmethod
-    def forward(
-        ctx,
-        dummy: torch.Tensor,
-        uid: str,
-        stub: ConnectionHandlerStub,
-        loop: asyncio.AbstractEventLoop,
-        info: Dict[str, Any],
-        *inputs: torch.Tensor,
-    ) -> Tuple[torch.Tensor, ...]:
-        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
-        # detach to avoid pickling the computation graph
-        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
-        ctx.uid, ctx.stub, ctx.info, ctx.loop = uid, stub, info, loop
-        ctx.save_for_backward(*inputs)
-
-        serialized_tensors = [
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
-
-        outputs = asyncio.run_coroutine_threadsafe(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
-            loop,
-        ).result()
-
-        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
-
-        return tuple(deserialized_outputs)
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
-        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
-        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
-            serialize_torch_tensor(tensor, proto.compression)
-            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
-
-        grad_inputs = asyncio.run_coroutine_threadsafe(
-            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
-            ctx.loop,
-        ).result()
-
-        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
-        return (DUMMY, None, None, None, *deserialized_grad_inputs)