
refactor moe package

Denis Mazur 4 years ago
parent
commit
bddf730480
4 changed files with 315 additions and 163 deletions
  1. benchmarks/benchmark_throughput-p2p.py (+236, -0)
  2. hivemind/moe/__init__.py (+1, -1)
  3. hivemind/moe/client/expert.py (+78, -22)
  4. hivemind/moe/expert.py (+0, -140)

benchmarks/benchmark_throughput-p2p.py (+236, -0)

@@ -0,0 +1,236 @@
+import argparse
+import multiprocessing as mp
+import random
+import sys
+import time
+
+import torch
+
+import hivemind
+from hivemind import get_free_port
+from hivemind.moe.server import layers
+from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+def print_device_info(device=None):
+    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
+    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
+    logger.info(f"Using device: {device}")
+
+    # Additional Info when using cuda
+    if device.type == "cuda":
+        logger.info(torch.cuda.get_device_name(0))
+        logger.info(f"Memory Usage:")
+        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
+        logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
+
+
+def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+    torch.set_num_threads(1)
+    can_start.wait()
+    experts = [
+        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
+    ]
+
+    try:
+        dummy_batch = torch.randn(batch_size, hid_dim)
+        for batch_i in range(num_batches):
+            expert = random.choice(experts)
+            out = expert(dummy_batch)
+            if backprop:
+                out.sum().backward()
+    except BaseException as e:
+        benchmarking_failed.set()
+        raise e
+
+
+def benchmark_throughput(
+    num_experts=16,
+    num_handlers=None,
+    num_clients=128,
+    num_batches_per_client=16,
+    expert_cls="ffn",
+    hid_dim=1024,
+    batch_size=2048,
+    max_batch_size=None,
+    backprop=True,
+    device=None,
+    port=None,
+):
+    assert (
+        not hasattr(torch.cuda, "is_initialized")
+        or not torch.cuda.is_initialized()
+        or torch.device(device) == torch.device("cpu")
+    )
+    assert expert_cls in layers.name_to_block
+    port = port or get_free_port()
+    max_batch_size = max_batch_size or batch_size * 4
+    num_handlers = max(1, num_handlers or num_clients // 2)
+    benchmarking_failed = mp.Event()
+    can_start = mp.Event()
+    timestamps = dict(started=time.perf_counter())
+
+    try:
+        # start clients and await server
+        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        clients = [
+            mp.Process(
+                target=client_process,
+                name=f"client_process-{i}",
+                args=(
+                    can_start,
+                    benchmarking_failed,
+                    port,
+                    num_experts,
+                    batch_size,
+                    hid_dim,
+                    num_batches_per_client,
+                    backprop,
+                ),
+            )
+            for i in range(num_clients)
+        ]
+
+        for client in clients:
+            client.daemon = True
+            client.start()
+
+        timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
+
+        # start server
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        experts = {}
+        for i in range(num_experts):
+            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
+            experts[f"expert{i}"] = hivemind.ExpertBackend(
+                name=f"expert{i}",
+                expert=expert,
+                optimizer=torch.optim.Adam(expert.parameters()),
+                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
+                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
+                max_batch_size=max_batch_size,
+            )
+        timestamps["created_experts"] = time.perf_counter()
+        server = hivemind.moe.Server(
+            None,
+            experts,
+            listen_on=f"{hivemind.LOCALHOST}:{port}",
+            num_connection_handlers=num_handlers,
+            device=device,
+        )
+        server.start()
+        server.ready.wait()
+        timestamps["server_ready"] = time.perf_counter()
+        can_start.set()
+
+        for client in clients:
+            client.join()
+        timestamps["clients_finished"] = time.perf_counter()
+    except BaseException as e:
+        benchmarking_failed.set()
+        raise e
+    finally:
+        for client in clients:
+            if client.is_alive():
+                client.terminate()
+        server.shutdown()
+        timestamps["server_shutdown_finished"] = time.perf_counter()
+        server.join()
+
+    sys.stdout.flush()
+    sys.stderr.flush()
+    time_between = (
+        lambda key1, key2: abs(timestamps[key2] - timestamps[key1])
+        if (key1 in timestamps and key2 in timestamps)
+        else float("nan")
+    )
+    total_examples = batch_size * num_clients * num_batches_per_client
+
+    logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
+    logger.info(
+        f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
+        f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}"
+    )
+    logger.info(
+        f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
+        f"batch_size={batch_size}, backprop={backprop}"
+    )
+
+    logger.info("Results: ")
+    logger.info(
+        f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
+        f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
+        f"{time_between('created_experts', 'server_ready') :.3f} s. networking)"
+    )
+    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
+    logger.info(
+        f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
+        f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s."
+    )
+    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
+    if benchmarking_failed.is_set():
+        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
+    print_device_info(device)
+    sys.stdout.flush()
+    sys.stderr.flush()
+
+    assert not benchmarking_failed.is_set()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--preset", type=str, default="default", required=False)
+    parser.add_argument("--num_batches_per_client", type=int, default=16, required=False)
+    args = parser.parse_args()
+
+    if args.preset in ("default", "ffn_forward_backward"):
+        benchmark_throughput()
+    elif args.preset == "ffn_forward":
+        benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
+    elif args.preset == "ffn_small_batch":
+        benchmark_throughput(
+            backprop=False,
+            num_experts=4,
+            batch_size=32,
+            max_batch_size=8192,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "ffn_small_batch_512clients":
+        benchmark_throughput(
+            backprop=True,
+            num_experts=1,
+            batch_size=1,
+            max_batch_size=8192,
+            num_clients=512,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "ffn_small_batch_512clients_32handlers":
+        benchmark_throughput(
+            backprop=True,
+            num_experts=1,
+            batch_size=1,
+            max_batch_size=8192,
+            num_handlers=32,
+            num_clients=512,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "ffn_massive":
+        increase_file_limit()
+        benchmark_throughput(
+            backprop=False,
+            num_clients=512,
+            batch_size=512,
+            max_batch_size=8192,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "minimalistic":
+        benchmark_throughput(
+            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
+        )
+    elif args.preset == "nop":
+        benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
+    else:
+        raise ValueError(f"No such benchmark preset: {args.preset}")
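
For reference, a minimal way to drive the new benchmark (a sketch, not part of the commit): the script is controlled by the --preset flag defined above, and the call below assumes hivemind and its MoE server dependencies are installed and that the command is run from the repository root.

# Sketch only: launch the p2p throughput benchmark with one of the presets
# defined in the __main__ block above (here "ffn_forward", forward passes only).
import subprocess
import sys

subprocess.run(
    [sys.executable, "benchmarks/benchmark_throughput-p2p.py", "--preset", "ffn_forward"],
    check=True,
)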

hivemind/moe/__init__.py (+1, -1)

@@ -1,2 +1,2 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class
+from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class, ConnectionHandler
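
A small usage note (not part of the diff): with ConnectionHandler added to the re-exports, it should now be importable straight from hivemind.moe alongside the existing entry points, e.g.:

# Sketch only: ConnectionHandler joins the names already re-exported by this __init__.py.
from hivemind.moe import ConnectionHandler, ExpertBackend, RemoteExpert, Server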

hivemind/moe/client/expert.py (+78, -22)

@@ -1,22 +1,33 @@
+import asyncio
+from concurrent.futures import Future
+import multiprocessing as mp
 import pickle
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Awaitable
+from threading import Thread
+from queue import Queue

 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable

-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
+
+#from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils.grpc import ChannelCache
+from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.p2p import P2P, PeerInfo, StubBase
+from hivemind.proto import runtime_pb2
+

 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


-def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
-    """Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache"""
-    channel_options = (("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)) + extra_options
-    return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
+import hivemind
+
+#ConnectionHandlerStub = hivemind.moe.server.connection_handler.ConnectionHandler._stub_type
+
+
+def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo): # -> ConnectionHandlerStub:
+    return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)


 class RemoteExpert(nn.Module):
@@ -26,17 +37,23 @@ class RemoteExpert(nn.Module):
     Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
     :param uid: unique expert identifier
-    :param endpoint: network endpoint of a server that services that expert, e.g. "201.123.321.99:1337" or "[::]:8080"
     """

-    def __init__(self, uid, endpoint: Endpoint):
+    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None):
         super().__init__()
-        self.uid, self.endpoint = uid, endpoint
+        self.uid, self.server_peer_info = uid, server_peer_info
         self._info = None

+        if p2p is None:
+            self.p2p = _RemoteModuleCall.run_coroutine(P2P.create())
+        else:
+            self.p2p = p2p
+
+        _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
+
     @property
-    def stub(self):
-        return _get_expert_stub(self.endpoint)
+    def stub(self) -> StubBase:
+        return _get_expert_stub(self.p2p, self.server_peer_info)

     def forward(self, *args, **kwargs):
         """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
@@ -49,31 +66,66 @@ class RemoteExpert(nn.Module):

         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
-
         flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
+
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])

     @property
     def info(self):
         if self._info is None:
-            outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
+            outputs = _RemoteModuleCall.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
             self._info = pickle.loads(outputs.serialized_info)
         return self._info

     def extra_repr(self):
-        return f"uid={self.uid}, endpoint={self.endpoint}"
+        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"


 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""

-    @staticmethod
+    _task_queue: Queue = Queue()
+    _event_thread: Optional[Thread] = None
+
+    @classmethod
+    def _run(cls):
+        loop = switch_to_uvloop()
+
+        async def receive_tasks():
+            while True:
+                cor, future = cls._task_queue.get()
+                try:
+                    result = await cor
+                except Exception as e:
+                    future.set_exception(e)
+                    continue
+
+                future.set_result(result)
+
+        loop.run_until_complete(receive_tasks())
+
+    @classmethod
+    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
+        if cls._event_thread is None:
+            cls._event_thread = Thread(target=cls._run)
+            cls._event_thread.start()
+
+        future = Future()
+        cls._task_queue.put((coro, future))
+
+        if return_future:
+            return future
+
+        return future.result()
+
+    @classmethod
     def forward(
+        cls,
         ctx,
         dummy: torch.Tensor,
         uid: str,
-        stub: runtime_grpc.ConnectionHandlerStub,
+        stub,#: ConnectionHandlerStub,
         info: Dict[str, Any],
         *inputs: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:
@@ -88,15 +140,17 @@ class _RemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]

-        outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        outputs = cls.run_coroutine(
+            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+        )

         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]

         return tuple(deserialized_outputs)

-    @staticmethod
+    @classmethod
     @once_differentiable
-    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+    def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
@@ -105,7 +159,9 @@ class _RemoteModuleCall(torch.autograd.Function):
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
         ]

-        grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        grad_inputs = cls.run_coroutine(
+            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+        )

         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, *deserialized_grad_inputs)
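
To make the refactored calling convention concrete, here is a minimal sketch (not part of the commit): RemoteExpert now identifies its server by a hivemind.p2p.PeerInfo instead of a host:port endpoint, and either reuses a caller-supplied P2P instance or creates its own via _RemoteModuleCall.run_coroutine. The helper name, the expert uid and the hidden dimension below are illustrative assumptions, and obtaining the server's PeerInfo (e.g. via the DHT) is outside this diff.

# A sketch of the new client path, assuming `server_peer_info` describes a running
# MoE server; everything named here besides RemoteExpert and PeerInfo is illustrative.
import torch

from hivemind.moe.client.expert import RemoteExpert
from hivemind.p2p import PeerInfo


def call_remote_expert(server_peer_info: PeerInfo, hid_dim: int = 1024) -> torch.Tensor:
    # With p2p=None, RemoteExpert creates its own P2P client and connects to the
    # server, exactly as in the __init__ added above.
    expert = RemoteExpert("expert0", server_peer_info)

    x = torch.randn(2, hid_dim, requires_grad=True)
    out = expert(x)       # runs rpc_forward on the server's ConnectionHandler stub
    out.sum().backward()  # runs rpc_backward; gradients are routed back to x
    return x.grad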

hivemind/moe/expert.py (+0, -140)

@@ -1,140 +0,0 @@
-import asyncio
-import multiprocessing as mp
-import pickle
-from typing import Any, Dict, Optional, Tuple
-from threading import Thread
-
-import torch
-import torch.nn as nn
-from torch.autograd.function import once_differentiable
-
-
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
-from hivemind.p2p import P2P, PeerInfo, StubBase
-from hivemind.proto import runtime_pb2
-from hivemind.moe.server import ConnectionHandler
-
-
-DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
-
-
-ConnectionHandlerStub = ConnectionHandler._stub_type
-
-
-def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> ConnectionHandlerStub:
-    return ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
-
-
-class RemoteExpert(nn.Module):
-    """
-    A simple module that runs forward/backward of an expert hosted on a remote machine.
-    Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
-    Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
-    Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
-    :param uid: unique expert identifier
-    """
-
-    def __init__(self, uid, p2p: P2P, server_peer_info: PeerInfo):
-        super().__init__()
-        self.uid, self.p2p, self.server_peer_info = uid, p2p, server_peer_info
-        self._info = None
-
-        def _run():
-            self.loop = switch_to_uvloop()
-            self.loop.run_forever()
-
-        Thread(target=_run).start()
-
-    @property
-    def stub(self) -> StubBase:
-        return _get_expert_stub(self.p2p, self.server_peer_info)
-
-    def forward(self, *args, **kwargs):
-        """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
-        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
-        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
-
-        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
-
-        forward_inputs = (args, kwargs)
-
-        if not nested_compare(forward_inputs, self.info["forward_schema"]):
-            raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
-        flat_outputs = _RemoteModuleCall.apply(
-            DUMMY,
-            self.uid,
-            self.stub,
-            self.loop,
-            self.info,
-            *nested_flatten(forward_inputs),
-        )
-
-        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
-        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
-
-    @property
-    def info(self):
-        if self._info is None:
-            outputs = asyncio.run_coroutine_threadsafe(
-                self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)),
-                self.loop
-            ).result()
-            self._info = pickle.loads(outputs.serialized_info)
-        return self._info
-
-    def extra_repr(self):
-        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
-
-
-class _RemoteModuleCall(torch.autograd.Function):
-    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
-
-    @staticmethod
-    def forward(
-        ctx,
-        dummy: torch.Tensor,
-        uid: str,
-        stub: ConnectionHandlerStub,
-        loop: asyncio.AbstractEventLoop,
-        info: Dict[str, Any],
-        *inputs: torch.Tensor,
-    ) -> Tuple[torch.Tensor, ...]:
-        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
-        # detach to avoid pickling the computation graph
-        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
-        ctx.uid, ctx.stub, ctx.info, ctx.loop = uid, stub, info, loop
-        ctx.save_for_backward(*inputs)
-
-        serialized_tensors = [
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
-
-        outputs = asyncio.run_coroutine_threadsafe(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
-            loop,
-        ).result()
-
-        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
-
-        return tuple(deserialized_outputs)
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
-        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
-        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
-            serialize_torch_tensor(tensor, proto.compression)
-            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
-
-        grad_inputs = asyncio.run_coroutine_threadsafe(
-            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
-            ctx.loop,
-        ).result()
-
-        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
-        return (DUMMY, None, None, None, *deserialized_grad_inputs)