Browse source

deleted deprecated benchmark_throughput_p2p.py (its p2p version now lives in benchmark_throughput.py); replicate p2p per pid; dht is mandatory for server

Pavel Samygin 3 years ago
parent
commit
5e057df59f

+ 47 - 25
benchmarks/benchmark_throughput.py

@@ -3,16 +3,16 @@ import multiprocessing as mp
 import random
 import sys
 import time
 
 import torch
 
-from hivemind.moe.client import RemoteExpert
-from hivemind.moe.server import ExpertBackend, Server
-from hivemind.moe.server.layers import name_to_block
+import hivemind
+from hivemind import P2P
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpertWorker
+from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.networking import LOCALHOST, get_free_port
-from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -31,10 +31,24 @@ def print_device_info(device=None):
         logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
 
 
-def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+def client_process(
+    can_start,
+    benchmarking_failed,
+    server_peer_info,
+    num_experts,
+    batch_size,
+    hid_dim,
+    num_batches,
+    backprop=True,
+) -> None:
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
+
+    p2p = RemoteExpertWorker.run_coroutine(P2P.create())
+    RemoteExpertWorker.run_coroutine(p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
+    experts = [
+        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)
+    ]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -59,15 +73,13 @@ def benchmark_throughput(
     max_batch_size=None,
     backprop=True,
     device=None,
-    port=None,
 ):
     assert (
         not hasattr(torch.cuda, "is_initialized")
         or not torch.cuda.is_initialized()
         or torch.device(device) == torch.device("cpu")
     )
-    assert expert_cls in name_to_block
-    port = port or get_free_port()
+    assert expert_cls in layers.name_to_block
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()
@@ -75,8 +87,12 @@
     timestamps = dict(started=time.perf_counter())
 
     try:
-        # start clients and await server
-        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        server_dht = DHT(start=True)
+        server_dht_peer_info = hivemind.PeerInfo(
+            peer_id=server_dht.peer_id,
+            addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
+        )
+
         clients = [
             mp.Process(
                 target=client_process,
@@ -84,52 +100,54 @@
                 args=(
                     can_start,
                     benchmarking_failed,
-                    port,
+                    server_dht_peer_info,
                     num_experts,
                     batch_size,
                     hid_dim,
                     num_batches_per_client,
                     backprop,
                 ),
+                daemon=True,
             )
             for i in range(num_clients)
         ]
 
         for client in clients:
-            client.daemon = True
             client.start()
 
         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
 
-        # start server
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
-            expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = ExpertBackend(
-                name=f"expert{i}",
+            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
+            experts[f"expert.{i}"] = hivemind.ExpertBackend(
+                name=f"expert.{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
-                args_schema=(BatchTensorDescriptor(hid_dim),),
-                outputs_schema=BatchTensorDescriptor(hid_dim),
+                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
+                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
-        server = Server(
-            None,
-            experts,
-            listen_on=f"{LOCALHOST}:{port}",
+
+        server = hivemind.moe.Server(
+            dht=server_dht,
+            expert_backends=experts,
             num_connection_handlers=num_handlers,
             device=device,
         )
         server.start()
         server.ready.wait()
+
         timestamps["server_ready"] = time.perf_counter()
         can_start.set()
 
         for client in clients:
             client.join()
+
         timestamps["clients_finished"] = time.perf_counter()
+
     except BaseException as e:
         benchmarking_failed.set()
         raise e
@@ -229,7 +247,11 @@ if __name__ == "__main__":
         )
     elif args.preset == "minimalistic":
         benchmark_throughput(
-            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
+            num_experts=1,
+            num_clients=1,
+            num_handlers=1,
+            num_batches_per_client=args.num_batches_per_client,
+            batch_size=1024,
         )
     elif args.preset == "nop":
         benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
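
Taken together, the rewritten benchmark wires clients to the server through libp2p instead of a raw host:port endpoint. A condensed sketch of the new flow, assembled only from calls that appear in this diff (the uid expert.0 is illustrative):

import hivemind
from hivemind import P2P
from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpertWorker

# server side: start a DHT node and derive a dialable PeerInfo from its visible multiaddrs
server_dht = DHT(start=True)
server_peer_info = hivemind.PeerInfo(
    peer_id=server_dht.peer_id,
    addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
)

# client side: spawn a P2P daemon, dial the server, then call experts like local modules
p2p = RemoteExpertWorker.run_coroutine(P2P.create())
RemoteExpertWorker.run_coroutine(p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
expert = hivemind.RemoteExpert("expert.0", server_peer_info=server_peer_info, p2p=p2p)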

+ 0 - 257
benchmarks/benchmark_throughput_p2p.py

@@ -1,257 +0,0 @@
-import argparse
-import multiprocessing as mp
-import random
-import sys
-import time
-
-import torch
-
-import hivemind
-from hivemind import P2P
-from hivemind.dht import DHT
-from hivemind.moe.server import layers
-from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__name__)
-
-
-def print_device_info(device=None):
-    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
-    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
-    logger.info(f"Using device: {device}")
-
-    # Additional Info when using cuda
-    if device.type == "cuda":
-        logger.info(torch.cuda.get_device_name(0))
-        logger.info(f"Memory Usage:")
-        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
-        logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
-
-
-def client_process(
-    can_start,
-    benchmarking_failed,
-    server_peer_info,
-    num_experts,
-    batch_size,
-    hid_dim,
-    num_batches,
-    backprop=True,
-) -> None:
-    torch.set_num_threads(1)
-    can_start.wait()
-
-    p2p = hivemind.moe.client.expert._RemoteModuleCall.run_coroutine(P2P.create())
-    experts = [
-        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)
-    ]
-
-    try:
-        dummy_batch = torch.randn(batch_size, hid_dim)
-        for batch_i in range(num_batches):
-            expert = random.choice(experts)
-            out = expert(dummy_batch)
-            if backprop:
-                out.sum().backward()
-    except BaseException as e:
-        benchmarking_failed.set()
-        raise e
-
-
-def benchmark_throughput(
-    num_experts=16,
-    num_handlers=None,
-    num_clients=128,
-    num_batches_per_client=16,
-    expert_cls="ffn",
-    hid_dim=1024,
-    batch_size=2048,
-    max_batch_size=None,
-    backprop=True,
-    device=None,
-):
-    assert (
-        not hasattr(torch.cuda, "is_initialized")
-        or not torch.cuda.is_initialized()
-        or torch.device(device) == torch.device("cpu")
-    )
-    assert expert_cls in layers.name_to_block
-    max_batch_size = max_batch_size or batch_size * 4
-    num_handlers = max(1, num_handlers or num_clients // 2)
-    benchmarking_failed = mp.Event()
-    can_start = mp.Event()
-    timestamps = dict(started=time.perf_counter())
-
-    try:
-        server_dht = DHT(start=True)
-        server_dht_peer_info = hivemind.PeerInfo(
-            peer_id=server_dht.peer_id,
-            addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
-        )
-
-        clients = [
-            mp.Process(
-                target=client_process,
-                name=f"client_process-{i}",
-                args=(
-                    can_start,
-                    benchmarking_failed,
-                    server_dht_peer_info,
-                    num_experts,
-                    batch_size,
-                    hid_dim,
-                    num_batches_per_client,
-                    backprop,
-                ),
-                daemon=True,
-            )
-            for i in range(num_clients)
-        ]
-
-        for client in clients:
-            client.start()
-
-        timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
-
-        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        experts = {}
-        for i in range(num_experts):
-            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f"expert.{i}"] = hivemind.ExpertBackend(
-                name=f"expert.{i}",
-                expert=expert,
-                optimizer=torch.optim.Adam(expert.parameters()),
-                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
-                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
-                max_batch_size=max_batch_size,
-            )
-        timestamps["created_experts"] = time.perf_counter()
-
-        server = hivemind.moe.Server(
-            dht=server_dht,
-            expert_backends=experts,
-            num_connection_handlers=num_handlers,
-            device=device,
-        )
-        server.start()
-        server.ready.wait()
-
-        timestamps["server_ready"] = time.perf_counter()
-        can_start.set()
-
-        for client in clients:
-            client.join()
-
-        timestamps["clients_finished"] = time.perf_counter()
-
-    except BaseException as e:
-        benchmarking_failed.set()
-        raise e
-    finally:
-        for client in clients:
-            if client.is_alive():
-                client.terminate()
-        server.shutdown()
-        timestamps["server_shutdown_finished"] = time.perf_counter()
-        server.join()
-
-    sys.stdout.flush()
-    sys.stderr.flush()
-    time_between = (
-        lambda key1, key2: abs(timestamps[key2] - timestamps[key1])
-        if (key1 in timestamps and key2 in timestamps)
-        else float("nan")
-    )
-    total_examples = batch_size * num_clients * num_batches_per_client
-
-    logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
-    logger.info(
-        f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
-        f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}"
-    )
-    logger.info(
-        f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
-        f"batch_size={batch_size}, backprop={backprop}"
-    )
-
-    logger.info("Results: ")
-    logger.info(
-        f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
-        f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
-        f"{time_between('created_experts', 'server_ready') :.3f} s. networking)"
-    )
-    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
-    logger.info(
-        f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
-        f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s."
-    )
-    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
-    if benchmarking_failed.is_set():
-        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
-    print_device_info(device)
-    sys.stdout.flush()
-    sys.stderr.flush()
-
-    assert not benchmarking_failed.is_set()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--preset", type=str, default="default", required=False)
-    parser.add_argument("--num_batches_per_client", type=int, default=16, required=False)
-    args = parser.parse_args()
-
-    if args.preset in ("default", "ffn_forward_backward"):
-        benchmark_throughput()
-    elif args.preset == "ffn_forward":
-        benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
-    elif args.preset == "ffn_small_batch":
-        benchmark_throughput(
-            backprop=False,
-            num_experts=4,
-            batch_size=32,
-            max_batch_size=8192,
-            num_batches_per_client=args.num_batches_per_client,
-        )
-    elif args.preset == "ffn_small_batch_512clients":
-        benchmark_throughput(
-            backprop=True,
-            num_experts=1,
-            batch_size=1,
-            max_batch_size=8192,
-            num_clients=512,
-            num_batches_per_client=args.num_batches_per_client,
-        )
-    elif args.preset == "ffn_small_batch_512clients_32handlers":
-        benchmark_throughput(
-            backprop=True,
-            num_experts=1,
-            batch_size=1,
-            max_batch_size=8192,
-            num_handlers=32,
-            num_clients=512,
-            num_batches_per_client=args.num_batches_per_client,
-        )
-    elif args.preset == "ffn_massive":
-        increase_file_limit()
-        benchmark_throughput(
-            backprop=False,
-            num_clients=512,
-            batch_size=512,
-            max_batch_size=8192,
-            num_batches_per_client=args.num_batches_per_client,
-        )
-    elif args.preset == "minimalistic":
-        benchmark_throughput(
-            num_experts=1,
-            num_clients=1,
-            num_handlers=1,
-            num_batches_per_client=args.num_batches_per_client,
-            batch_size=1024,
-        )
-    elif args.preset == "nop":
-        benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
-    else:
-        raise ValueError(f"No such benchmark preset: {args.preset}")

+ 3 - 1
hivemind/dht/dht.py

@@ -55,6 +55,7 @@ class DHT(mp.Process):
         **kwargs,
     ):
         self._parent_pid = os.getpid()
+        self._my_pid = os.getpid()
         super().__init__()
 
         if not (
@@ -310,7 +311,8 @@ class DHT(mp.Process):
         The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
         """
 
-        if self._p2p_replica is None:
+        if self._p2p_replica is None or self._my_pid != os.getpid():
+            self._my_pid = os.getpid()
             daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
             self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
         return self._p2p_replica
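
This per-PID guard is what the commit message means by "replicate p2p per pid": the cached replica is invalidated whenever os.getpid() changes, so a process forked after the cache was filled builds its own replica instead of reusing its parent's, whose daemon connection would not survive the fork. A minimal self-contained sketch of the idiom; the names here are illustrative, not hivemind API:

import os

class ProcessLocalCache:
    """Rebuilds a cached resource whenever the current PID changes."""

    def __init__(self, factory):
        self._factory = factory  # callable that creates a fresh resource
        self._value, self._pid = None, None

    def get(self):
        # a forked child inherits _value but sees a new PID, so it re-creates the resource
        if self._value is None or self._pid != os.getpid():
            self._pid = os.getpid()
            self._value = self._factory()
        return self._value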

BIN
hivemind/hivemind_cli/p2pd_old


BIN
hivemind/hivemind_cli/p2pd_old2


+ 0 - 1
hivemind/hivemind_cli/run_server.py

@@ -48,7 +48,6 @@ def main():
     parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')
 
-    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
     parser.add_argument('--increase_file_limit', action='store_true',

+ 59 - 19
hivemind/moe/client/expert.py

@@ -1,16 +1,21 @@
+import os
 from concurrent.futures import Future
+from dataclasses import dataclass
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, List, Optional, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple, Union
 
+from multiaddr import Multiaddr
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 from hivemind.proto import runtime_pb2
 from hivemind.utils import (
     MSGPackSerializer,
@@ -22,6 +27,7 @@ from hivemind.utils import (
     switch_to_uvloop,
 )
 from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
+from hivemind.utils.mpfuture import MPFuture
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
@@ -29,6 +35,19 @@ DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autogra
 def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
     return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
 
+@dataclass(frozen=True)
+class RemoteExpertInfo:
+    uid: str
+    peer_id: str
+    addrs: Sequence[str]
+
+    @property
+    def as_peer_info(self) -> Tuple[str, PeerInfo]:
+        return self.uid, PeerInfo(
+            peer_id=PeerID.from_base58(self.peer_id),
+            addrs=tuple(Multiaddr(a) for a in self.addrs)
+        )
+
 
 class RemoteExpert(nn.Module):
     """
@@ -39,17 +58,11 @@
     :param uid: unique expert identifier
     """
 
-    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None, connect: bool = True):
+    def __init__(self, uid, server_peer_info: PeerInfo, p2p: P2P):
         super().__init__()
-        self.uid, self.server_peer_info = uid, server_peer_info
+        self.uid, self.server_peer_info, self.p2p = uid, server_peer_info, p2p
         self._info = None
 
-        if p2p is None:
-            self.p2p = _RemoteModuleCall.run_coroutine(P2P.create())
-            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
-        else:
-            self.p2p = p2p
-
     @property
     def stub(self) -> StubBase:
         return _get_expert_stub(self.p2p, self.server_peer_info)
@@ -74,7 +87,7 @@
     @property
     def info(self):
         if self._info is None:
-            outputs = _RemoteModuleCall.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
             self._info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._info
 
@@ -82,11 +95,12 @@
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
 
 
-class _RemoteModuleCall(torch.autograd.Function):
-    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
+class RemoteExpertWorker:
+    """Local thread for managing async tasks related to RemoteExpert"""
 
     _task_queue: Queue = Queue()
     _event_thread: Optional[Thread] = None
+    _pid: int = 0
 
     @classmethod
     def _run(cls):
@@ -106,7 +120,8 @@
 
     @classmethod
     def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
-        if cls._event_thread is None:
+        if cls._event_thread is None or cls._pid != os.getpid():
+            cls._pid = os.getpid()
             cls._event_thread = Thread(target=cls._run, daemon=True)
             cls._event_thread.start()
 
@@ -119,6 +134,31 @@
         result = future.result()
         return result
 
+    @classmethod
+    def spawn_experts_future(
+        cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
+    ) -> MPFuture[List[Optional[RemoteExpert]]]:
+        async def _unpack():
+            return cls.spawn_experts(await infos, dht)
+
+        return cls.run_coroutine(_unpack(), return_future=True)
+
+    @classmethod
+    def spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], dht: DHT) -> List[Optional[RemoteExpert]]:
+        p2p = cls.run_coroutine(dht.replicate_p2p())
+        experts: List[Optional[RemoteExpert]] = []
+        for i in infos:
+            if i is not None:
+                uid, peer_info = i.as_peer_info
+                experts.append(RemoteExpert(uid, peer_info, p2p))
+            else:
+                experts.append(None)
+        return experts
+
+
+class _RemoteModuleCall(torch.autograd.Function):
+    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
+
     @classmethod
     def forward(
         cls,
@@ -155,7 +195,7 @@ class _RemoteModuleCall(torch.autograd.Function):
     def forward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
         split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
 
-        outputs = cls.run_coroutine(
+        outputs = RemoteExpertWorker.run_coroutine(
             stub.rpc_forward_partial(
                 amap_in_executor(
                     lambda t: runtime_pb2.ExpertRequest(
@@ -169,12 +209,12 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        return cls.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))
 
     @classmethod
     def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
 
-        outputs = cls.run_coroutine(
+        outputs = RemoteExpertWorker.run_coroutine(
             stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
 
@@ -207,7 +247,7 @@ class _RemoteModuleCall(torch.autograd.Function):
     def backward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
-        grad_inputs = cls.run_coroutine(
+        grad_inputs = RemoteExpertWorker.run_coroutine(
             ctx.stub.rpc_backward_partial(
                 amap_in_executor(
                     lambda t: runtime_pb2.ExpertRequest(
@@ -221,12 +261,12 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        return cls.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))
 
     @classmethod
     @once_differentiable
     def backward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
-        grad_inputs = cls.run_coroutine(
+        grad_inputs = RemoteExpertWorker.run_coroutine(
             ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
 

+ 5 - 11
hivemind/moe/server/dht_handler.py

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 from multiaddr import Multiaddr
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     UID_DELIMITER,
@@ -85,20 +85,15 @@ def get_experts(
     :returns: a list of [RemoteExpert if found else None]
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-    p2p = _RemoteModuleCall.run_coroutine(dht.replicate_p2p())
     result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
-
-    def _unwrap_experts(vals: List[Optional[LazyValue[RemoteExpert]]]) -> List[Optional[RemoteExpert]]:
-        return [val.get(p2p=p2p) if val is not None else None for val in vals]
-
     if return_future:
-        return LazyFutureCaller(result, _unwrap_experts)
-    return _unwrap_experts(result)
+        return RemoteExpertWorker.spawn_experts_future(result, dht)
+    return RemoteExpertWorker.spawn_experts(result, dht)
 
 
 async def _get_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[LazyValue[RemoteExpert]]]:
+) -> List[Optional[RemoteExpertInfo]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
@@ -109,6 +104,5 @@ async def _get_experts(
         elem = found[uid]
         if elem is not None and isinstance(elem.value, tuple):
             peer_id, addrs = elem.value
-            peer_info = PeerInfo(peer_id=PeerID.from_base58(peer_id), addrs=tuple(Multiaddr(a) for a in addrs))
-            experts[i] = LazyValue(init=partial(RemoteExpert, uid=uid, server_peer_info=peer_info))
+            experts[i] = RemoteExpertInfo(uid, peer_id, addrs)
     return experts
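
With RemoteExpertInfo, only picklable uid/peer_id/addrs records cross the DHT process boundary; the RemoteExpert objects, which hold a live P2P handle, are built on the caller's side by RemoteExpertWorker.spawn_experts. A hedged usage sketch; the peer address is a placeholder:

from hivemind.dht import DHT
from hivemind.moe.server.dht_handler import get_experts

# hypothetical multiaddr of a running hivemind server's DHT node
initial_peers = ["/ip4/192.0.2.1/tcp/31337/p2p/QmServerPeerDemo123"]

dht = DHT(initial_peers=initial_peers, start=True)
experts = get_experts(dht, ["expert.0", "expert.777"])
# -> [RemoteExpert, None]; entries are None for uids never declared on this DHT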

+ 17 - 24
hivemind/moe/server/server.py

@@ -42,8 +42,7 @@ class Server(threading.Thread):
      - publishes updates to expert status every :update_period: seconds
      - follows orders from HivemindController - if it exists
 
-    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
+    :param dht: a running DHT instance; attaching the server to a DHT is now mandatory
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
     :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
     :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
@@ -56,7 +55,7 @@ class Server(threading.Thread):
 
     def __init__(
         self,
-        dht: Optional[DHT],
+        dht: DHT,
         expert_backends: Dict[str, ExpertBackend],
         num_connection_handlers: int = 1,
         update_period: int = 30,
@@ -74,7 +73,7 @@ class Server(threading.Thread):
             self.checkpoint_saver = None
         self.runtime = Runtime(self.experts, **kwargs)
 
-        if self.dht and self.experts:
+        if self.experts:
             self.dht_handler_thread = DHTHandlerThread(
                 experts=self.experts,
                 dht=self.dht,
@@ -103,7 +102,6 @@ class Server(threading.Thread):
         min_batch_size=1,
         max_batch_size=4096,
         device=None,
-        no_dht=False,
         initial_peers=(),
         checkpoint_dir: Optional[Path] = None,
         compression=CompressionType.NONE,
@@ -132,7 +130,6 @@ class Server(threading.Thread):
         :param num_total_steps: the total number of steps for LR schedule
         :param clip_grad_norm: maximum gradient norm used for clipping
 
-        :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
 
         :param checkpoint_dir: directory to save and load expert checkpoints
@@ -148,12 +145,9 @@ class Server(threading.Thread):
             add_custom_models_from_file(custom_module_path)
         assert expert_cls in name_to_block
 
-        if no_dht:
-            dht = None
-        else:
-            dht = DHT(initial_peers=initial_peers, start=True)
-            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        dht = DHT(initial_peers=initial_peers, start=True)
+        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
         assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
             num_experts is not None and expert_uids is None
@@ -234,12 +228,12 @@ class Server(threading.Thread):
             num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
             logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
 
-        if self.dht:
-            if not self.dht.is_alive():
-                self.dht.run_in_background(await_ready=True)
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.experts:
+            self.dht_handler_thread.start()
 
-            if self.experts:
-                self.dht_handler_thread.start()
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
 
@@ -288,7 +282,7 @@ class Server(threading.Thread):
             process.join()
         logger.debug("Connection handlers terminated")
 
-        if self.dht and self.experts:
+        if self.experts:
             self.dht_handler_thread.stop.set()
             self.dht_handler_thread.join()
 
@@ -296,9 +290,8 @@ class Server(threading.Thread):
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()
 
-        if self.dht is not None:
-            self.dht.shutdown()
-            self.dht.join()
+        self.dht.shutdown()
+        self.dht.join()
 
         logger.debug(f"Shutting down runtime")
 
@@ -314,7 +307,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, Li
     try:
         runner.start()
         # once the server is ready, runner will send us
-        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
+        # either (False, exception) or (True, (dht_peer_id, dht_maddrs))
         start_ok, data = pipe.recv()
         if start_ok:
             yield data
@@ -338,8 +331,8 @@ def _server_runner(pipe, *args, **kwargs):
         return
 
     try:
-        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
-        pipe.send((True, (server.listen_on, dht_maddrs)))
+        dht_maddrs = server.dht.get_visible_maddrs()
+        pipe.send((True, (server.dht.peer_id, dht_maddrs)))
         pipe.recv()  # wait for shutdown signal
 
     finally:
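
Because a Server now always owns a DHT, background_server's handshake payload changed accordingly: the context manager yields (peer_id, dht_maddrs) instead of (listen_on, dht_maddrs). A hedged sketch of the updated contract; the keyword arguments are illustrative and simply forwarded to Server.create:

from hivemind.moe.server import background_server

with background_server(num_experts=1, expert_cls="ffn", hidden_dim=64) as (peer_id, dht_maddrs):
    print("server peer id:", peer_id)
    print("bootstrap a client DHT via:", [str(m) for m in dht_maddrs])
# on exit, the server and its DHT are shut down automatically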