Browse Source

deleted deprecated. replicate p2p per pid. dht is mandatory for server

Pavel Samygin 3 years ago
parent
commit
5e057df59f

+ 48 - 25
benchmarks/benchmark_throughput.py

@@ -3,16 +3,17 @@ import multiprocessing as mp
 import random
 import sys
 import time
+from grpc import server

 import torch

-from hivemind.moe.client import RemoteExpert
-from hivemind.moe.server import ExpertBackend, Server
-from hivemind.moe.server.layers import name_to_block
+import hivemind
+from hivemind import P2P
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpertWorker
+from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.networking import LOCALHOST, get_free_port
-from hivemind.utils.tensor_descr import BatchTensorDescriptor

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -31,10 +32,24 @@ def print_device_info(device=None):
         logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")


-def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+def client_process(
+    can_start,
+    benchmarking_failed,
+    server_peer_info,
+    num_experts,
+    batch_size,
+    hid_dim,
+    num_batches,
+    backprop=True,
+) -> None:
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
+
+    p2p = RemoteExpertWorker.run_coroutine(P2P.create())
+    RemoteExpertWorker.run_coroutine(p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
+    experts = [
+        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)
+    ]

     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -59,15 +74,13 @@ def benchmark_throughput(
     max_batch_size=None,
     backprop=True,
     device=None,
-    port=None,
 ):
     assert (
         not hasattr(torch.cuda, "is_initialized")
         or not torch.cuda.is_initialized()
         or torch.device(device) == torch.device("cpu")
     )
-    assert expert_cls in name_to_block
-    port = port or get_free_port()
+    assert expert_cls in layers.name_to_block
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()
@@ -75,8 +88,12 @@ def benchmark_throughput(
     timestamps = dict(started=time.perf_counter())

     try:
-        # start clients and await server
-        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        server_dht = DHT(start=True)
+        server_dht_peer_info = hivemind.PeerInfo(
+            peer_id=server_dht.peer_id,
+            addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
+        )
+
         clients = [
             mp.Process(
                 target=client_process,
@@ -84,52 +101,54 @@ def benchmark_throughput(
                 args=(
                     can_start,
                     benchmarking_failed,
-                    port,
+                    server_dht_peer_info,
                     num_experts,
                     batch_size,
                     hid_dim,
                     num_batches_per_client,
                     backprop,
                 ),
+                daemon=True,
             )
             for i in range(num_clients)
         ]

         for client in clients:
-            client.daemon = True
             client.start()

         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()

-        # start server
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
-            expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = ExpertBackend(
-                name=f"expert{i}",
+            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
+            experts[f"expert.{i}"] = hivemind.ExpertBackend(
+                name=f"expert.{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
-                args_schema=(BatchTensorDescriptor(hid_dim),),
-                outputs_schema=BatchTensorDescriptor(hid_dim),
+                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
+                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
-        server = Server(
-            None,
-            experts,
-            listen_on=f"{LOCALHOST}:{port}",
+
+        server = hivemind.moe.Server(
+            dht=server_dht,
+            expert_backends=experts,
             num_connection_handlers=num_handlers,
             device=device,
         )
         server.start()
         server.ready.wait()
+
         timestamps["server_ready"] = time.perf_counter()
         can_start.set()

         for client in clients:
             client.join()
+
         timestamps["clients_finished"] = time.perf_counter()
+
     except BaseException as e:
         benchmarking_failed.set()
         raise e
@@ -229,7 +248,11 @@ if __name__ == "__main__":
         )
     elif args.preset == "minimalistic":
         benchmark_throughput(
-            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
+            num_experts=1,
+            num_clients=1,
+            num_handlers=1,
+            num_batches_per_client=args.num_batches_per_client,
+            batch_size=1024,
         )
     elif args.preset == "nop":
         benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
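
Note on the updated benchmark: the client no longer dials the server by "IP:PORT"; it connects to the server's libp2p peer directly. A minimal sketch of this client path, mirroring client_process above (it assumes a server hosting experts named "expert.0", "expert.1", ... is already running and its hivemind.PeerInfo is available as server_peer_info):

    import torch
    import hivemind
    from hivemind import P2P
    from hivemind.moe.client.expert import RemoteExpertWorker

    # spawn a lightweight p2p client in this process and dial the server's peer
    p2p = RemoteExpertWorker.run_coroutine(P2P.create())
    RemoteExpertWorker.run_coroutine(p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))

    # RemoteExpert now takes the server's PeerInfo instead of an endpoint string
    expert = hivemind.RemoteExpert("expert.0", server_peer_info=server_peer_info, p2p=p2p)
    output = expert(torch.randn(4, 1024))  # forward pass over libp2p
    output.sum().backward()                # backward pass, as in the backprop=True presets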

+ 0 - 257
benchmarks/benchmark_throughput_p2p.py

@@ -1,257 +0,0 @@
-import argparse
-import multiprocessing as mp
-import random
-import sys
-import time
-
-import torch
-
-import hivemind
-from hivemind import P2P
-from hivemind.dht import DHT
-from hivemind.moe.server import layers
-from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__name__)
-
-
-def print_device_info(device=None):
-    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
-    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
-    logger.info(f"Using device: {device}")
-
-    # Additional Info when using cuda
-    if device.type == "cuda":
-        logger.info(torch.cuda.get_device_name(0))
-        logger.info(f"Memory Usage:")
-        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
-        logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
-
-
-def client_process(
-    can_start,
-    benchmarking_failed,
-    server_peer_info,
-    num_experts,
-    batch_size,
-    hid_dim,
-    num_batches,
-    backprop=True,
-) -> None:
-    torch.set_num_threads(1)
-    can_start.wait()
-
-    p2p = hivemind.moe.client.expert._RemoteModuleCall.run_coroutine(P2P.create())
-    experts = [
-        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)
-    ]
-
-    try:
-        dummy_batch = torch.randn(batch_size, hid_dim)
-        for batch_i in range(num_batches):
-            expert = random.choice(experts)
-            out = expert(dummy_batch)
-            if backprop:
-                out.sum().backward()
-    except BaseException as e:
-        benchmarking_failed.set()
-        raise e
-
-
-def benchmark_throughput(
-    num_experts=16,
-    num_handlers=None,
-    num_clients=128,
-    num_batches_per_client=16,
-    expert_cls="ffn",
-    hid_dim=1024,
-    batch_size=2048,
-    max_batch_size=None,
-    backprop=True,
-    device=None,
-):
-    assert (
-        not hasattr(torch.cuda, "is_initialized")
-        or not torch.cuda.is_initialized()
-        or torch.device(device) == torch.device("cpu")
-    )
-    assert expert_cls in layers.name_to_block
-    max_batch_size = max_batch_size or batch_size * 4
-    num_handlers = max(1, num_handlers or num_clients // 2)
-    benchmarking_failed = mp.Event()
-    can_start = mp.Event()
-    timestamps = dict(started=time.perf_counter())
-
-    try:
-        server_dht = DHT(start=True)
-        server_dht_peer_info = hivemind.PeerInfo(
-            peer_id=server_dht.peer_id,
-            addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
-        )
-
-        clients = [
-            mp.Process(
-                target=client_process,
-                name=f"client_process-{i}",
-                args=(
-                    can_start,
-                    benchmarking_failed,
-                    server_dht_peer_info,
-                    num_experts,
-                    batch_size,
-                    hid_dim,
-                    num_batches_per_client,
-                    backprop,
-                ),
-                daemon=True,
-            )
-            for i in range(num_clients)
-        ]
-
-        for client in clients:
-            client.start()
-
-        timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
-
-        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        experts = {}
-        for i in range(num_experts):
-            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f"expert.{i}"] = hivemind.ExpertBackend(
-                name=f"expert.{i}",
-                expert=expert,
-                optimizer=torch.optim.Adam(expert.parameters()),
-                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
-                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
-                max_batch_size=max_batch_size,
-            )
-        timestamps["created_experts"] = time.perf_counter()
-
-        server = hivemind.moe.Server(
-            dht=server_dht,
-            expert_backends=experts,
-            num_connection_handlers=num_handlers,
-            device=device,
-        )
-        server.start()
-        server.ready.wait()
-
-        timestamps["server_ready"] = time.perf_counter()
-        can_start.set()
-
-        for client in clients:
-            client.join()
-
-        timestamps["clients_finished"] = time.perf_counter()
-
-    except BaseException as e:
-        benchmarking_failed.set()
-        raise e
-    finally:
-        for client in clients:
-            if client.is_alive():
-                client.terminate()
-        server.shutdown()
-        timestamps["server_shutdown_finished"] = time.perf_counter()
-        server.join()
-
-    sys.stdout.flush()
-    sys.stderr.flush()
-    time_between = (
-        lambda key1, key2: abs(timestamps[key2] - timestamps[key1])
-        if (key1 in timestamps and key2 in timestamps)
-        else float("nan")
-    )
-    total_examples = batch_size * num_clients * num_batches_per_client
-
-    logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
-    logger.info(
-        f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
-        f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}"
-    )
-    logger.info(
-        f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
-        f"batch_size={batch_size}, backprop={backprop}"
-    )
-
-    logger.info("Results: ")
-    logger.info(
-        f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
-        f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
-        f"{time_between('created_experts', 'server_ready') :.3f} s. networking)"
-    )
-    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
-    logger.info(
-        f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
-        f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s."
-    )
-    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
-    if benchmarking_failed.is_set():
-        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
-    print_device_info(device)
-    sys.stdout.flush()
-    sys.stderr.flush()
-
-    assert not benchmarking_failed.is_set()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--preset", type=str, default="default", required=False)
-    parser.add_argument("--num_batches_per_client", type=int, default=16, required=False)
-    args = parser.parse_args()
-
-    if args.preset in ("default", "ffn_forward_backward"):
-        benchmark_throughput()
-    elif args.preset == "ffn_forward":
-        benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
-    elif args.preset == "ffn_small_batch":
-        benchmark_throughput(
-            backprop=False,
-            num_experts=4,
-            batch_size=32,
-            max_batch_size=8192,
-            num_batches_per_client=args.num_batches_per_client,
-        )
-    elif args.preset == "ffn_small_batch_512clients":
-        benchmark_throughput(
-            backprop=True,
-            num_experts=1,
-            batch_size=1,
-            max_batch_size=8192,
-            num_clients=512,
-            num_batches_per_client=args.num_batches_per_client,
-        )
-    elif args.preset == "ffn_small_batch_512clients_32handlers":
-        benchmark_throughput(
-            backprop=True,
-            num_experts=1,
-            batch_size=1,
-            max_batch_size=8192,
-            num_handlers=32,
-            num_clients=512,
-            num_batches_per_client=args.num_batches_per_client,
-        )
-    elif args.preset == "ffn_massive":
-        increase_file_limit()
-        benchmark_throughput(
-            backprop=False,
-            num_clients=512,
-            batch_size=512,
-            max_batch_size=8192,
-            num_batches_per_client=args.num_batches_per_client,
-        )
-    elif args.preset == "minimalistic":
-        benchmark_throughput(
-            num_experts=1,
-            num_clients=1,
-            num_handlers=1,
-            num_batches_per_client=args.num_batches_per_client,
-            batch_size=1024,
-        )
-    elif args.preset == "nop":
-        benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
-    else:
-        raise ValueError(f"No such benchmark preset: {args.preset}")

+ 3 - 1
hivemind/dht/dht.py

@@ -55,6 +55,7 @@ class DHT(mp.Process):
         **kwargs,
     ):
         self._parent_pid = os.getpid()
+        self._my_pid = os.getpid()
         super().__init__()

         if not (
@@ -310,7 +311,8 @@ class DHT(mp.Process):
         The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
         """

-        if self._p2p_replica is None:
+        if self._p2p_replica is None or self._my_pid != os.getpid():
+            self._my_pid = os.getpid()
             daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
             self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
         return self._p2p_replica
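
The _my_pid guard above makes replicate_p2p fork-aware: a P2P replica cached in the parent process is not reused in a child process (for example an mp.Process spawned for a client), so each PID re-replicates its own connection to the daemon. An illustrative sketch of the same per-process caching idiom (not hivemind API, just the pattern):

    import os

    class PerProcessCache:
        """Cache one object per process and rebuild it after a fork."""

        def __init__(self, factory):
            self._factory = factory
            self._value = None
            self._owner_pid = None

        def get(self):
            # a value built in another process must not be reused after a fork
            if self._value is None or self._owner_pid != os.getpid():
                self._owner_pid = os.getpid()
                self._value = self._factory()
            return self._value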

Binary
hivemind/hivemind_cli/p2pd_old


Binary
hivemind/hivemind_cli/p2pd_old2


+ 0 - 1
hivemind/hivemind_cli/run_server.py

@@ -48,7 +48,6 @@ def main():
     parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')

-    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
     parser.add_argument('--increase_file_limit', action='store_true',

+ 59 - 19
hivemind/moe/client/expert.py

@@ -1,16 +1,22 @@
+from dataclasses import dataclass
 from concurrent.futures import Future
+from lib2to3.pgen2.token import OP
+from multiaddr import Multiaddr
+import os
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, List, Optional, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable

 import hivemind
+from hivemind.dht import DHT
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 from hivemind.proto import runtime_pb2
 from hivemind.utils import (
     MSGPackSerializer,
@@ -22,6 +28,7 @@ from hivemind.utils import (
     switch_to_uvloop,
 )
 from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
+from hivemind.utils.mpfuture import MPFuture

 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert

@@ -29,6 +36,19 @@ DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autogra
 def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
     return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)

+@dataclass(frozen=True)
+class RemoteExpertInfo:
+    uid: str
+    peer_id: str
+    addrs: Sequence[str]
+
+    @property
+    def as_peer_info(self) -> Tuple[str, PeerInfo]:
+        return self.uid, PeerInfo(
+            peer_id=PeerID.from_base58(self.peer_id),
+            addrs=tuple(Multiaddr(a) for a in self.addrs)
+        )
+

 class RemoteExpert(nn.Module):
     """
@@ -39,17 +59,11 @@ class RemoteExpert(nn.Module):
     :param uid: unique expert identifier
     """

-    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None, connect: bool = True):
+    def __init__(self, uid, server_peer_info: PeerInfo, p2p: P2P):
         super().__init__()
-        self.uid, self.server_peer_info = uid, server_peer_info
+        self.uid, self.server_peer_info, self.p2p = uid, server_peer_info, p2p
         self._info = None

-        if p2p is None:
-            self.p2p = _RemoteModuleCall.run_coroutine(P2P.create())
-            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
-        else:
-            self.p2p = p2p
-
     @property
     def stub(self) -> StubBase:
         return _get_expert_stub(self.p2p, self.server_peer_info)
@@ -74,7 +88,7 @@ class RemoteExpert(nn.Module):
     @property
     def info(self):
         if self._info is None:
-            outputs = _RemoteModuleCall.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
             self._info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._info

@@ -82,11 +96,13 @@ class RemoteExpert(nn.Module):
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"


-class _RemoteModuleCall(torch.autograd.Function):
-    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
+class RemoteExpertWorker:
+    """Local thread for managing async tasks related to RemoteExpert"""

     _task_queue: Queue = Queue()
     _event_thread: Optional[Thread] = None
+    _pid: int = 0
+

     @classmethod
     def _run(cls):
@@ -106,7 +122,8 @@ class _RemoteModuleCall(torch.autograd.Function):

     @classmethod
     def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
-        if cls._event_thread is None:
+        if cls._event_thread is None or cls._pid != os.getpid():
+            cls._pid = os.getpid()
             cls._event_thread = Thread(target=cls._run, daemon=True)
             cls._event_thread.start()

@@ -119,6 +136,29 @@ class _RemoteModuleCall(torch.autograd.Function):
         result = future.result()
         return result

+    @classmethod
+    def spawn_experts_future(cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT) -> MPFuture[List[Optional[RemoteExpert]]]:
+        async def _unpack():
+            return cls.spawn_experts(await infos, dht)
+        return cls.run_coroutine(_unpack, True)
+
+    @classmethod
+    def spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], dht: DHT) -> List[Optional[RemoteExpert]]:
+        p2p = cls.run_coroutine(dht.replicate_p2p())
+        experts: List[Optional[RemoteExpert]] = []
+        for i in infos:
+            if i is not None:
+                uid, peer_info = i.as_peer_info
+                experts.append(RemoteExpert(uid, peer_info, p2p))
+            else:
+                experts.append(None)
+        return experts
+
+
+
+class _RemoteModuleCall(torch.autograd.Function):
+    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
+
     @classmethod
     def forward(
         cls,
@@ -155,7 +195,7 @@ class _RemoteModuleCall(torch.autograd.Function):
     def forward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
         split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]

-        outputs = cls.run_coroutine(
+        outputs = RemoteExpertWorker.run_coroutine(
             stub.rpc_forward_partial(
                 amap_in_executor(
                     lambda t: runtime_pb2.ExpertRequest(
@@ -169,12 +209,12 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )

-        return cls.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))

     @classmethod
     def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:

-        outputs = cls.run_coroutine(
+        outputs = RemoteExpertWorker.run_coroutine(
             stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )

@@ -207,7 +247,7 @@ class _RemoteModuleCall(torch.autograd.Function):
     def backward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

-        grad_inputs = cls.run_coroutine(
+        grad_inputs = RemoteExpertWorker.run_coroutine(
             ctx.stub.rpc_backward_partial(
                 amap_in_executor(
                     lambda t: runtime_pb2.ExpertRequest(
@@ -221,12 +261,12 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )

-        return cls.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))

     @classmethod
     @once_differentiable
     def backward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
-        grad_inputs = cls.run_coroutine(
+        grad_inputs = RemoteExpertWorker.run_coroutine(
             ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )


+ 5 - 11
hivemind/moe/server/dht_handler.py

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 from multiaddr import Multiaddr

 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     UID_DELIMITER,
@@ -85,20 +85,15 @@ def get_experts(
     :returns: a list of [RemoteExpert if found else None]
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-    p2p = _RemoteModuleCall.run_coroutine(dht.replicate_p2p())
     result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
-
-    def _unwrap_experts(vals: List[Optional[LazyValue[RemoteExpert]]]) -> List[Optional[RemoteExpert]]:
-        return [val.get(p2p=p2p) if val is not None else None for val in vals]
-
     if return_future:
-        return LazyFutureCaller(result, _unwrap_experts)
-    return _unwrap_experts(result)
+        return RemoteExpertWorker.spawn_experts_future(result, dht)
+    return RemoteExpertWorker.spawn_experts(result, dht)


 async def _get_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[LazyValue[RemoteExpert]]]:
+) -> List[Optional[RemoteExpertInfo]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
@@ -109,6 +104,5 @@ async def _get_experts(
         elem = found[uid]
         if elem is not None and isinstance(elem.value, tuple):
             peer_id, addrs = elem.value
-            peer_info = PeerInfo(peer_id=PeerID.from_base58(peer_id), addrs=tuple(Multiaddr(a) for a in addrs))
-            experts[i] = LazyValue(init=partial(RemoteExpert, uid=uid, server_peer_info=peer_info))
+            experts[i] = RemoteExpertInfo(uid, peer_id, addrs)
     return experts
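
With this change, the DHT lookup returns plain RemoteExpertInfo records (uid, peer_id, addrs), and RemoteExpertWorker.spawn_experts turns them into RemoteExpert instances over the DHT's replicated P2P daemon. A hedged usage sketch (it assumes a Server has already declared these uids; initial_peers below is a placeholder for that server's multiaddrs):

    from hivemind.dht import DHT
    from hivemind.moe.server.dht_handler import get_experts

    dht = DHT(initial_peers=initial_peers, start=True)
    experts = get_experts(dht, ["expert.0", "expert.1"])
    # each entry is a RemoteExpert bound to the declaring server's PeerInfo, or None if the uid was not found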

+ 17 - 24
hivemind/moe/server/server.py

@@ -42,8 +42,7 @@ class Server(threading.Thread):
      - publishes updates to expert status every :update_period: seconds
      - follows orders from HivemindController - if it exists

-    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
+    :type dht: DHT.
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
     :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
     :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
@@ -56,7 +55,7 @@ class Server(threading.Thread):

     def __init__(
         self,
-        dht: Optional[DHT],
+        dht: DHT,
         expert_backends: Dict[str, ExpertBackend],
         num_connection_handlers: int = 1,
         update_period: int = 30,
@@ -74,7 +73,7 @@ class Server(threading.Thread):
             self.checkpoint_saver = None
         self.runtime = Runtime(self.experts, **kwargs)

-        if self.dht and self.experts:
+        if self.experts:
             self.dht_handler_thread = DHTHandlerThread(
                 experts=self.experts,
                 dht=self.dht,
@@ -103,7 +102,6 @@ class Server(threading.Thread):
         min_batch_size=1,
         max_batch_size=4096,
         device=None,
-        no_dht=False,
         initial_peers=(),
         checkpoint_dir: Optional[Path] = None,
         compression=CompressionType.NONE,
@@ -132,7 +130,6 @@ class Server(threading.Thread):
         :param num_total_steps: the total number of steps for LR schedule
         :param clip_grad_norm: maximum gradient norm used for clipping

-        :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)

         :param checkpoint_dir: directory to save and load expert checkpoints
@@ -148,12 +145,9 @@ class Server(threading.Thread):
             add_custom_models_from_file(custom_module_path)
         assert expert_cls in name_to_block

-        if no_dht:
-            dht = None
-        else:
-            dht = DHT(initial_peers=initial_peers, start=True)
-            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        dht = DHT(initial_peers=initial_peers, start=True)
+        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

         assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
             num_experts is not None and expert_uids is None
@@ -234,12 +228,12 @@ class Server(threading.Thread):
             num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
             logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")

-        if self.dht:
-            if not self.dht.is_alive():
-                self.dht.run_in_background(await_ready=True)
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.experts:
+            self.dht_handler_thread.start()

-            if self.experts:
-                self.dht_handler_thread.start()
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()

@@ -288,7 +282,7 @@ class Server(threading.Thread):
             process.join()
         logger.debug("Connection handlers terminated")

-        if self.dht and self.experts:
+        if self.experts:
             self.dht_handler_thread.stop.set()
             self.dht_handler_thread.join()

@@ -296,9 +290,8 @@ class Server(threading.Thread):
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()

-        if self.dht is not None:
-            self.dht.shutdown()
-            self.dht.join()
+        self.dht.shutdown()
+        self.dht.join()

         logger.debug(f"Shutting down runtime")

@@ -314,7 +307,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, Li
     try:
         runner.start()
         # once the server is ready, runner will send us
-        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
+        # either (False, exception) or (True, (dht_peer_id, dht_maddrs))
         start_ok, data = pipe.recv()
         if start_ok:
             yield data
@@ -338,8 +331,8 @@ def _server_runner(pipe, *args, **kwargs):
         return

     try:
-        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
-        pipe.send((True, (server.listen_on, dht_maddrs)))
+        dht_maddrs = server.dht.get_visible_maddrs()
+        pipe.send((True, (server.dht.peer_id, dht_maddrs)))
         pipe.recv()  # wait for shutdown signal

     finally:
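
Because dht=None is no longer supported, background_server now reports the DHT peer id and visible multiaddrs instead of a listen_on endpoint. A rough sketch of consuming the new value (the import path and keyword arguments are assumptions based on the file location and Server.create, not confirmed here):

    from hivemind.moe.server import background_server

    with background_server(expert_cls="ffn", num_experts=2) as (peer_id, dht_maddrs):
        # clients reach the server via its DHT peer id / multiaddrs rather than an "IP:PORT" endpoint
        print(peer_id, [str(a) for a in dht_maddrs])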