File overview

Merge remote-tracking branch 'lah/master' into power_ef_new

Artem Chumachenko 3 years ago
parent
commit
e9e7b0fa45

+ 13 - 13
benchmarks/benchmark_throughput.py

@@ -6,11 +6,13 @@ import time
 
 import torch
 
-import hivemind
-from hivemind import get_free_port
-from hivemind.moe.server import layers
+from hivemind.moe.client import RemoteExpert
+from hivemind.moe.server import ExpertBackend, Server
+from hivemind.moe.server.layers import name_to_block
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import LOCALHOST, get_free_port
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -32,9 +34,7 @@ def print_device_info(device=None):
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [
-        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
-    ]
+    experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -66,7 +66,7 @@ def benchmark_throughput(
         or not torch.cuda.is_initialized()
         or torch.device(device) == torch.device("cpu")
     )
-    assert expert_cls in layers.name_to_block
+    assert expert_cls in name_to_block
     port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
@@ -105,20 +105,20 @@ def benchmark_throughput(
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
-            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = hivemind.ExpertBackend(
+            expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
+            experts[f"expert{i}"] = ExpertBackend(
                 name=f"expert{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
-                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
-                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
+                args_schema=(BatchTensorDescriptor(hid_dim),),
+                outputs_schema=BatchTensorDescriptor(hid_dim),
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
-        server = hivemind.moe.Server(
+        server = Server(
             None,
             experts,
-            listen_on=f"{hivemind.LOCALHOST}:{port}",
+            listen_on=f"{LOCALHOST}:{port}",
             num_connection_handlers=num_handlers,
             device=device,
         )
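
For reference, a minimal sketch of the round trip this benchmark performs, written against the new direct imports. It mirrors the calls visible in the hunk above (expert name, endpoint format, ExpertBackend arguments); the single-expert setup, batch size, and CPU-only run are illustrative assumptions, not part of the benchmark itself.

```python
# Sketch only: one local expert served and queried over the loopback interface.
import torch

from hivemind.moe.client import RemoteExpert
from hivemind.moe.server import ExpertBackend, Server
from hivemind.moe.server.layers import name_to_block
from hivemind.utils.networking import LOCALHOST, get_free_port
from hivemind.utils.tensor_descr import BatchTensorDescriptor

hid_dim, port = 16, get_free_port()

# server side: a jit-scripted "ffn" expert wrapped into an ExpertBackend
expert = torch.jit.script(name_to_block["ffn"](hid_dim))
backend = ExpertBackend(
    name="expert0",
    expert=expert,
    optimizer=torch.optim.Adam(expert.parameters()),
    args_schema=(BatchTensorDescriptor(hid_dim),),
    outputs_schema=BatchTensorDescriptor(hid_dim),
    max_batch_size=64,
)
server = Server(None, {"expert0": backend}, listen_on=f"{LOCALHOST}:{port}", start=True)

try:
    # client side: forward and backward through the remote expert
    remote = RemoteExpert("expert0", endpoint=f"{LOCALHOST}:{port}")
    x = torch.randn(4, hid_dim, requires_grad=True)
    remote(x).sum().backward()
finally:
    server.shutdown()
    server.join()
```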

+ 8 - 6
docs/modules/server.rst

@@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo
 The hivemind.moe.server module is organized as follows:
 
 - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute.
-- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
 - ExpertBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
   that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests.
+- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
 - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \
   and offers those batches to Runtime_ for processing.
 
@@ -25,16 +25,18 @@ The hivemind.moe.server module is organized as follows:
    :members:
    :member-order: bysource
 
-.. _Runtime:
-.. autoclass:: Runtime
-    :members:
-    :member-order: bysource
-
 .. _ExpertBackend:
 .. autoclass:: ExpertBackend
     :members: forward, backward, apply_gradients, get_info, get_pools
     :member-order: bysource
 
+.. currentmodule:: hivemind.moe.server.runtime
+
+.. _Runtime:
+.. autoclass:: Runtime
+    :members:
+    :member-order: bysource
+
 .. currentmodule:: hivemind.moe.server.task_pool
 
 .. _TaskPool:

+ 1 - 1
hivemind/hivemind_cli/run_server.py

@@ -4,7 +4,7 @@ from pathlib import Path
 import configargparse
 import torch
 
-from hivemind.moe.server import Server
+from hivemind.moe import Server
 from hivemind.moe.server.layers import schedule_name_to_scheduler
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
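
With Server now re-exported from hivemind.moe, the programmatic equivalent of what this CLI entry point drives looks roughly as follows. The argument values are illustrative rather than the CLI defaults, and only parameters documented in Server.create (shown later in this diff) are used.

```python
# Illustrative values only; the real CLI parses these options via configargparse.
from hivemind.moe import Server

server = Server.create(
    num_experts=2,
    expert_pattern="ffn.[0:4].[0:4]",
    expert_cls="ffn",
    hidden_dim=64,
    no_dht=True,   # run standalone, without joining a DHT
    start=True,    # returns once the server is ready to accept requests
)
try:
    print(f"Serving {len(server.experts)} experts on {server.listen_on}")
finally:
    server.shutdown()
    server.join()
```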

+ 8 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,9 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class
+from hivemind.moe.server import (
+    ExpertBackend,
+    Server,
+    background_server,
+    declare_experts,
+    get_experts,
+    register_expert_class,
+)

+ 4 - 4
hivemind/moe/client/moe.py

@@ -9,8 +9,8 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
@@ -48,7 +48,7 @@ class RemoteMixtureOfExperts(nn.Module):
         *,
         in_features,
         grid_size: Tuple[int, ...],
-        dht: hivemind.DHT,
+        dht: DHT,
         uid_prefix: str,
         k_best: int,
         k_min: int = 1,
@@ -245,7 +245,7 @@ class _RemoteCallMany(torch.autograd.Function):
         else:
             outputs_schema = info["outputs_schema"]
         outputs = nested_map(
-            lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
+            lambda descriptor: descriptor.make_zeros(num_samples, max_experts, device=flat_inputs[0].device),
             outputs_schema,
         )
 
@@ -341,7 +341,7 @@ class _RemoteCallMany(torch.autograd.Function):
         # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
         grad_inputs = nested_map(
-            lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
+            lambda descr: descr.make_zeros(num_samples, device=flat_grad_outputs[0].device),
             list(nested_flatten(info["forward_schema"])),
         )
 

+ 3 - 355
hivemind/moe/server/__init__.py

@@ -1,356 +1,4 @@
-from __future__ import annotations
-
-import multiprocessing as mp
-import multiprocessing.synchronize
-import threading
-from contextlib import contextmanager
-from functools import partial
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from multiaddr import Multiaddr
-
-import hivemind
-from hivemind.dht import DHT
-from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
-from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
+from hivemind.moe.server.dht_handler import declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
-from hivemind.moe.server.layers import (
-    add_custom_models_from_file,
-    name_to_block,
-    name_to_input,
-    register_expert_class,
-    schedule_name_to_scheduler,
-)
-from hivemind.moe.server.runtime import Runtime
-from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
-
-logger = get_logger(__name__)
-
-
-class Server(threading.Thread):
-    """
-    Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
-    After creation, a server should be started: see Server.run or Server.run_in_background.
-
-    A working server does 3 things:
-     - processes incoming forward/backward requests via Runtime (created by the server)
-     - publishes updates to expert status every :update_period: seconds
-     - follows orders from HivemindController - if it exists
-
-    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
-    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all experts hosted by this server.
-    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
-    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
-        is too small for normal functioning; we recommend 4 handlers per expert backend.
-    :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
-        if dht is None, this parameter is ignored.
-    :param start: if True, the server will immediately start as a background thread and returns control after server
-        is ready (see .ready below)
-    """
-
-    def __init__(
-        self,
-        dht: Optional[DHT],
-        expert_backends: Dict[str, ExpertBackend],
-        listen_on: Endpoint = "0.0.0.0:*",
-        num_connection_handlers: int = 1,
-        update_period: int = 30,
-        start=False,
-        checkpoint_dir=None,
-        **kwargs,
-    ):
-        super().__init__()
-        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=get_free_port())
-        self.listen_on, self.port = listen_on, get_port(listen_on)
-
-        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
-        if checkpoint_dir is not None:
-            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
-        else:
-            self.checkpoint_saver = None
-        self.runtime = Runtime(self.experts, **kwargs)
-
-        if self.dht and self.experts:
-            self.dht_handler_thread = DHTHandlerThread(
-                experts=self.experts,
-                dht=self.dht,
-                endpoint=self.listen_on,
-                update_period=self.update_period,
-                daemon=True,
-            )
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-    @classmethod
-    def create(
-        cls,
-        listen_on="0.0.0.0:*",
-        num_experts: int = None,
-        expert_uids: str = None,
-        expert_pattern: str = None,
-        expert_cls="ffn",
-        hidden_dim=1024,
-        optim_cls=torch.optim.Adam,
-        scheduler: str = "none",
-        num_warmup_steps=None,
-        num_total_steps=None,
-        clip_grad_norm=None,
-        num_handlers=None,
-        min_batch_size=1,
-        max_batch_size=4096,
-        device=None,
-        no_dht=False,
-        initial_peers=(),
-        checkpoint_dir: Optional[Path] = None,
-        compression=CompressionType.NONE,
-        stats_report_interval: Optional[int] = None,
-        custom_module_path=None,
-        *,
-        start: bool,
-    ) -> Server:
-        """
-        Instantiate a server with several identical experts. See argparse comments below for details
-        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
-        :param num_experts: run this many identical experts
-        :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-           means "sample random experts between myprefix.0.0 and myprefix.255.255;
-        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
-        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
-        :param hidden_dim: main dimension for expert_cls
-        :param num_handlers: server will use this many parallel processes to handle incoming requests
-        :param min_batch_size: total num examples in the same batch will be greater than this value
-        :param max_batch_size: total num examples in the same batch will not exceed this value
-        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
-
-        :param optim_cls: uses this optimizer to train all experts
-        :param scheduler: if not `none`, the name of the expert LR scheduler
-        :param num_warmup_steps: the number of warmup steps for LR schedule
-        :param num_total_steps: the total number of steps for LR schedule
-        :param clip_grad_norm: maximum gradient norm used for clipping
-
-        :param no_dht: if specified, the server will not be attached to a dht
-        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
-
-        :param checkpoint_dir: directory to save and load expert checkpoints
-
-        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
-            hosted on this server. For a more fine-grained compression, start server in python and specify compression
-            for each BatchTensorProto in ExpertBackend for the respective experts.
-
-        :param start: if True, starts server right away and returns when server is ready for requests
-        :param stats_report_interval: interval between two reports of batch processing performance statistics
-        """
-        if custom_module_path is not None:
-            add_custom_models_from_file(custom_module_path)
-        assert expert_cls in name_to_block
-
-        if no_dht:
-            dht = None
-        else:
-            dht = hivemind.DHT(initial_peers=initial_peers, start=True)
-            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
-
-        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
-            num_experts is not None and expert_uids is None
-        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
-
-        if expert_uids is None:
-            if checkpoint_dir is not None:
-                assert is_directory(checkpoint_dir)
-                expert_uids = [
-                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
-                ]
-                total_experts_in_checkpoint = len(expert_uids)
-                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
-
-                if total_experts_in_checkpoint > num_experts:
-                    raise ValueError(
-                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
-                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
-                    )
-            else:
-                expert_uids = []
-
-            uids_to_generate = num_experts - len(expert_uids)
-            if uids_to_generate > 0:
-                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
-                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
-
-        num_experts = len(expert_uids)
-        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
-        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-
-        sample_input = name_to_input[expert_cls](3, hidden_dim)
-        if isinstance(sample_input, tuple):
-            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
-        else:
-            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
-
-        scheduler = schedule_name_to_scheduler[scheduler]
-
-        # initialize experts
-        experts = {}
-        for expert_uid in expert_uids:
-            expert = name_to_block[expert_cls](hidden_dim)
-            experts[expert_uid] = hivemind.ExpertBackend(
-                name=expert_uid,
-                expert=expert,
-                args_schema=args_schema,
-                optimizer=optim_cls(expert.parameters()),
-                scheduler=scheduler,
-                num_warmup_steps=num_warmup_steps,
-                num_total_steps=num_total_steps,
-                clip_grad_norm=clip_grad_norm,
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        if checkpoint_dir is not None:
-            load_experts(experts, checkpoint_dir)
-
-        return cls(
-            dht,
-            experts,
-            listen_on=listen_on,
-            num_connection_handlers=num_handlers,
-            device=device,
-            checkpoint_dir=checkpoint_dir,
-            stats_report_interval=stats_report_interval,
-            start=start,
-        )
-
-    def run(self):
-        """
-        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
-        runs Runtime (self.runtime) to process incoming requests.
-        """
-        logger.info(f"Server started at {self.listen_on}")
-        logger.info(f"Got {len(self.experts)} experts:")
-        for expert_name, backend in self.experts.items():
-            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
-
-        if self.dht:
-            if not self.dht.is_alive():
-                self.dht.run_in_background(await_ready=True)
-
-            if self.experts:
-                self.dht_handler_thread.start()
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
-        for process in self.conn_handlers:
-            if not process.is_alive():
-                process.start()
-            process.ready.wait()
-
-        try:
-            self.runtime.run()
-        finally:
-            self.shutdown()
-
-    def run_in_background(self, await_ready=True, timeout=None):
-        """
-        Starts Server in a background thread. if await_ready, this method will wait until background server
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
-
-    @property
-    def ready(self) -> mp.synchronize.Event:
-        """
-        An event (multiprocessing.Event) that is set when the server is ready to process requests.
-
-        Example
-        =======
-        >>> server.start()
-        >>> server.ready.wait(timeout=10)
-        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
-        """
-        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
-
-    def shutdown(self):
-        """
-        Gracefully terminate the server, process-safe.
-        Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
-        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
-        """
-        self.ready.clear()
-
-        for process in self.conn_handlers:
-            process.terminate()
-            process.join()
-        logger.debug("Connection handlers terminated")
-
-        if self.dht and self.experts:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
-
-        if self.dht is not None:
-            self.dht.shutdown()
-            self.dht.join()
-
-        logger.debug(f"Shutting down runtime")
-
-        self.runtime.shutdown()
-        logger.info("Server shut down successfully")
-
-
-@contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, List[Multiaddr]]:
-    """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
-    pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
-    try:
-        runner.start()
-        # once the server is ready, runner will send us
-        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
-        start_ok, data = pipe.recv()
-        if start_ok:
-            yield data
-            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
-        else:
-            raise RuntimeError(f"Server failed to start: {data}")
-    finally:
-        runner.join(timeout=shutdown_timeout)
-        if runner.is_alive():
-            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
-            runner.kill()
-            logger.info("Server terminated")
-
-
-def _server_runner(pipe, *args, **kwargs):
-    try:
-        server = Server.create(*args, start=True, **kwargs)
-    except Exception as e:
-        logger.exception(f"Encountered an exception when starting a server: {e}")
-        pipe.send((False, f"{type(e).__name__} {e}"))
-        return
-
-    try:
-        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
-        pipe.send((True, (server.listen_on, dht_maddrs)))
-        pipe.recv()  # wait for shutdown signal
-
-    finally:
-        logger.info("Shutting down server...")
-        server.shutdown()
-        server.join()
-        logger.info("Server shut down")
+from hivemind.moe.server.layers import register_expert_class
+from hivemind.moe.server.server import Server, background_server
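
The package now only re-exports the moved symbols, so declare_experts and get_experts keep their old import path. A small sketch of declaring expert uids on a DHT and reading them back, modeled on test_beam_search_correctness and the duplicate check in _generate_uids; the uids and endpoint string are placeholders.

```python
# The endpoint is a placeholder, as in test_beam_search_correctness below.
from hivemind.dht import DHT
from hivemind.moe.server import declare_experts, get_experts

dht = DHT(start=True)
uids = ["demo.0.0", "demo.0.1"]
assert all(declare_experts(dht, uids, endpoint="fake-endpoint"))

found = get_experts(dht, uids + ["demo.9.9"])  # unknown uids come back as None
print([expert.uid if expert is not None else None for expert in found])
dht.shutdown()
```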

+ 2 - 2
hivemind/moe/server/expert_backend.py

@@ -74,8 +74,8 @@ class ExpertBackend:
 
         if outputs_schema is None:
             # run expert once to get outputs schema
-            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
-            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
+            dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
+            dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 

+ 2 - 73
hivemind/moe/server/expert_uid.py

@@ -1,12 +1,7 @@
-import random
 import re
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import NamedTuple, Tuple, Union
 
-import hivemind
-from hivemind.dht import DHT
-from hivemind.utils import Endpoint, get_logger
-
-logger = get_logger(__name__)
+from hivemind.utils import Endpoint
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
 UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
@@ -32,69 +27,3 @@ def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPref
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
     pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
     return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
-
-
-def generate_uids_from_pattern(
-    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
-) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if "[" not in block and "]" not in block:
-                    uid.append(block)
-                elif block.startswith("[") and block.endswith("]") and ":" in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(":"))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt:
-                raise
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {
-                found_expert.uid
-                for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
-                if found_expert is not None
-            }
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(
-            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-            f"{attempts_per_expert * num_experts} attempts"
-        )
-    return found_uids

+ 419 - 0
hivemind/moe/server/server.py

@@ -0,0 +1,419 @@
+from __future__ import annotations
+
+import multiprocessing as mp
+import random
+import threading
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from multiaddr import Multiaddr
+
+from hivemind.dht import DHT
+from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.moe.server.layers import (
+    add_custom_models_from_file,
+    name_to_block,
+    name_to_input,
+    schedule_name_to_scheduler,
+)
+from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.logging import get_logger
+from hivemind.utils.networking import Endpoint, get_free_port, get_port, replace_port
+from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
+
+logger = get_logger(__name__)
+
+
+class Server(threading.Thread):
+    """
+    Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts.
+    After creation, a server should be started: see Server.run or Server.run_in_background.
+
+    A working server does two things:
+     - processes incoming forward/backward requests via Runtime (created by the server)
+     - publishes updates to expert status every :update_period: seconds
+
+    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
+     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
+    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all experts hosted by this server.
+    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
+    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
+        is too small for normal functioning; we recommend 4 handlers per expert backend.
+    :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
+        if dht is None, this parameter is ignored.
+    :param start: if True, the server will immediately start as a background thread and returns control after server
+        is ready (see .ready below)
+    """
+
+    def __init__(
+        self,
+        dht: Optional[DHT],
+        expert_backends: Dict[str, ExpertBackend],
+        listen_on: Endpoint = "0.0.0.0:*",
+        num_connection_handlers: int = 1,
+        update_period: int = 30,
+        start=False,
+        checkpoint_dir=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
+        if get_port(listen_on) is None:
+            listen_on = replace_port(listen_on, new_port=get_free_port())
+        self.listen_on, self.port = listen_on, get_port(listen_on)
+
+        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
+        if checkpoint_dir is not None:
+            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
+        else:
+            self.checkpoint_saver = None
+        self.runtime = Runtime(self.experts, **kwargs)
+
+        if self.dht and self.experts:
+            self.dht_handler_thread = DHTHandlerThread(
+                experts=self.experts,
+                dht=self.dht,
+                endpoint=self.listen_on,
+                update_period=self.update_period,
+                daemon=True,
+            )
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    @classmethod
+    def create(
+        cls,
+        listen_on="0.0.0.0:*",
+        num_experts: int = None,
+        expert_uids: str = None,
+        expert_pattern: str = None,
+        expert_cls="ffn",
+        hidden_dim=1024,
+        optim_cls=torch.optim.Adam,
+        scheduler: str = "none",
+        num_warmup_steps=None,
+        num_total_steps=None,
+        clip_grad_norm=None,
+        num_handlers=None,
+        min_batch_size=1,
+        max_batch_size=4096,
+        device=None,
+        no_dht=False,
+        initial_peers=(),
+        checkpoint_dir: Optional[Path] = None,
+        compression=CompressionType.NONE,
+        stats_report_interval: Optional[int] = None,
+        custom_module_path=None,
+        *,
+        start: bool,
+    ) -> Server:
+        """
+        Instantiate a server with several identical experts. See argparse comments below for details
+        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
+        :param num_experts: run this many identical experts
+        :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
+           means "sample random experts between myprefix.0.0 and myprefix.255.255;
+        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
+        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
+        :param hidden_dim: main dimension for expert_cls
+        :param num_handlers: server will use this many parallel processes to handle incoming requests
+        :param min_batch_size: total num examples in the same batch will be greater than this value
+        :param max_batch_size: total num examples in the same batch will not exceed this value
+        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+
+        :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_total_steps: the total number of steps for LR schedule
+        :param clip_grad_norm: maximum gradient norm used for clipping
+
+        :param no_dht: if specified, the server will not be attached to a dht
+        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
+
+        :param checkpoint_dir: directory to save and load expert checkpoints
+
+        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
+            hosted on this server. For a more fine-grained compression, start server in python and specify compression
+            for each BatchTensorProto in ExpertBackend for the respective experts.
+
+        :param start: if True, starts server right away and returns when server is ready for requests
+        :param stats_report_interval: interval between two reports of batch processing performance statistics
+        """
+        if custom_module_path is not None:
+            add_custom_models_from_file(custom_module_path)
+        assert expert_cls in name_to_block
+
+        if no_dht:
+            dht = None
+        else:
+            dht = DHT(initial_peers=initial_peers, start=True)
+            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+
+        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
+            num_experts is not None and expert_uids is None
+        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
+
+        if expert_uids is None:
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [
+                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
+                ]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
+                    )
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(_generate_uids(uids_to_generate, expert_pattern, dht))
+
+        num_experts = len(expert_uids)
+        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
+        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+        sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim)
+        if isinstance(sample_input, tuple):
+            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
+        else:
+            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
+
+        scheduler = schedule_name_to_scheduler[scheduler]
+
+        # initialize experts
+        experts = {}
+        for expert_uid in expert_uids:
+            expert = name_to_block[expert_cls](hidden_dim)
+            experts[expert_uid] = ExpertBackend(
+                name=expert_uid,
+                expert=expert,
+                args_schema=args_schema,
+                optimizer=optim_cls(expert.parameters()),
+                scheduler=scheduler,
+                num_warmup_steps=num_warmup_steps,
+                num_total_steps=num_total_steps,
+                clip_grad_norm=clip_grad_norm,
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )
+
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)
+
+        return cls(
+            dht,
+            experts,
+            listen_on=listen_on,
+            num_connection_handlers=num_handlers,
+            device=device,
+            checkpoint_dir=checkpoint_dir,
+            stats_report_interval=stats_report_interval,
+            start=start,
+        )
+
+    def run(self):
+        """
+        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        logger.info(f"Server started at {self.listen_on}")
+        logger.info(f"Got {len(self.experts)} experts:")
+        for expert_name, backend in self.experts.items():
+            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
+
+        if self.dht:
+            if not self.dht.is_alive():
+                self.dht.run_in_background(await_ready=True)
+
+            if self.experts:
+                self.dht_handler_thread.start()
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for process in self.conn_handlers:
+            if not process.is_alive():
+                process.start()
+            process.ready.wait()
+
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts Server in a background thread. if await_ready, this method will wait until background server
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
+
+    @property
+    def ready(self) -> mp.synchronize.Event:
+        """
+        An event (multiprocessing.Event) that is set when the server is ready to process requests.
+
+        Example
+        =======
+        >>> server.start()
+        >>> server.ready.wait(timeout=10)
+        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
+        """
+        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
+
+    def shutdown(self):
+        """
+        Gracefully terminate the server, process-safe.
+        Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
+        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
+        """
+        self.ready.clear()
+
+        for process in self.conn_handlers:
+            process.terminate()
+            process.join()
+        logger.debug("Connection handlers terminated")
+
+        if self.dht and self.experts:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
+
+        if self.dht is not None:
+            self.dht.shutdown()
+            self.dht.join()
+
+        logger.debug(f"Shutting down runtime")
+
+        self.runtime.shutdown()
+        logger.info("Server shut down successfully")
+
+
+@contextmanager
+def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, List[Multiaddr]]:
+    """A context manager that creates server in a background process, awaits .ready on entry and shuts down on exit"""
+    pipe, runners_pipe = mp.Pipe(duplex=True)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    try:
+        runner.start()
+        # once the server is ready, runner will send us
+        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
+        start_ok, data = pipe.recv()
+        if start_ok:
+            yield data
+            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
+        else:
+            raise RuntimeError(f"Server failed to start: {data}")
+    finally:
+        runner.join(timeout=shutdown_timeout)
+        if runner.is_alive():
+            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
+            runner.kill()
+            logger.info("Server terminated")
+
+
+def _server_runner(pipe, *args, **kwargs):
+    try:
+        server = Server.create(*args, start=True, **kwargs)
+    except Exception as e:
+        logger.exception(f"Encountered an exception when starting a server: {e}")
+        pipe.send((False, f"{type(e).__name__} {e}"))
+        return
+
+    try:
+        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
+        pipe.send((True, (server.listen_on, dht_maddrs)))
+        pipe.recv()  # wait for shutdown signal
+
+    finally:
+        logger.info("Shutting down server...")
+        server.shutdown()
+        server.join()
+        logger.info("Server shut down")
+
+
+def _generate_uids(
+    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
+) -> List[str]:
+    """
+    Sample experts from a given pattern, remove duplicates.
+    :param num_experts: sample this many unique expert uids
+    :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
+     means "sample random experts between myprefix.0.0 and myprefix.255.255;
+    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
+    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
+    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
+     a small chance of sampling duplicate expert uids.
+    """
+    remaining_attempts = attempts_per_expert * num_experts
+    found_uids, attempted_uids = list(), set()
+
+    def _generate_uid():
+        if expert_pattern is None:
+            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
+
+        uid = []
+        for block in expert_pattern.split(UID_DELIMITER):
+            try:
+                if "[" not in block and "]" not in block:
+                    uid.append(block)
+                elif block.startswith("[") and block.endswith("]") and ":" in block:
+                    slice_start, slice_end = map(int, block[1:-1].split(":"))
+                    uid.append(str(random.randint(slice_start, slice_end - 1)))
+                else:
+                    raise ValueError("Block must be either fixed or a range [from:to]")
+            except KeyboardInterrupt:
+                raise
+            except Exception as e:
+                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
+        return UID_DELIMITER.join(uid)
+
+    while remaining_attempts > 0 and len(found_uids) < num_experts:
+
+        # 1. sample new expert uids at random
+        new_uids = []
+        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
+            new_uid = _generate_uid()
+            remaining_attempts -= 1
+            if new_uid not in attempted_uids:
+                attempted_uids.add(new_uid)
+                new_uids.append(new_uid)
+
+        # 2. look into DHT (if given) and remove duplicates
+        if dht is not None:
+            existing_expert_uids = {
+                found_expert.uid for found_expert in get_experts(dht, new_uids) if found_expert is not None
+            }
+            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
+
+        found_uids += new_uids
+
+    if len(found_uids) != num_experts:
+        logger.warning(
+            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
+            f"{attempts_per_expert * num_experts} attempts"
+        )
+    return found_uids
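
The expert_pattern syntax accepted by Server.create and _generate_uids reads more easily from a condensed restatement. The sketch below reproduces only the single-uid sampling step of the helper above, for illustration; it is not the library function itself.

```python
import random

UID_DELIMITER = "."  # same delimiter as hivemind.moe.server.expert_uid


def sample_uid(pattern: str) -> str:
    """Fixed blocks pass through; "[a:b]" blocks become a random integer in [a, b)."""
    parts = []
    for block in pattern.split(UID_DELIMITER):
        if block.startswith("[") and block.endswith("]") and ":" in block:
            lo, hi = map(int, block[1:-1].split(":"))
            parts.append(str(random.randint(lo, hi - 1)))
        else:
            parts.append(block)
    return UID_DELIMITER.join(parts)


print(sample_uid("myprefix.[0:32].[0:256]"))  # e.g. "myprefix.17.203"
```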

+ 6 - 2
hivemind/optim/optimizer.py

@@ -535,8 +535,12 @@ class Optimizer(torch.optim.Optimizer):
                 logger.exception(e)
 
         if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
-            logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
-            self._tag_along_with_zero_weight(self.scheduled_grads)
+            if self.tracker.global_progress.num_peers > 1:
+                logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+                self._tag_along_with_zero_weight(self.scheduled_grads)
+            else:
+                logger.log(self.status_loglevel, f"Skipping pre-scheduled averaging round: there are no other peers")
+                self.scheduled_grads.cancel()
             self.scheduled_grads = None
         return began_averaging_gradients
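
A runnable sketch of the new decision above, using a plain concurrent.futures.Future as a stand-in for the scheduled gradient-averaging round; the function and variable names here are hypothetical, not part of hivemind's API.

```python
# Stand-in sketch of the branch added above: only tag along if there are peers.
from concurrent.futures import Future


def maybe_tag_along(scheduled_grads: Future, num_peers: int) -> None:
    if scheduled_grads is None or scheduled_grads.done():
        return
    if num_peers > 1:
        print("Tagging along for a pre-scheduled gradient averaging round")
        # the real Optimizer calls self._tag_along_with_zero_weight(scheduled_grads) here
    else:
        print("Skipping pre-scheduled averaging round: there are no other peers")
        scheduled_grads.cancel()


maybe_tag_along(Future(), num_peers=1)  # takes the "skipping" branch and cancels
```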
 

+ 4 - 4
hivemind/utils/tensor_descr.py

@@ -46,11 +46,11 @@ class TensorDescriptor(DescriptorBase):
             tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
         )
 
-    def make_empty(self, **kwargs):
+    def make_zeros(self, **kwargs):
         properties = asdict(self)
         properties.update(kwargs)
         properties.pop("compression")
-        return torch.empty(**properties)
+        return torch.zeros(**properties)
 
 
 def _str_to_torch_type(name: str, torch_type: type):
@@ -86,9 +86,9 @@ class BatchTensorDescriptor(TensorDescriptor):
             compression=compression if tensor.is_floating_point() else CompressionType.NONE,
         )
 
-    def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
+    def make_zeros(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
-        return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
+        return super().make_zeros(size=(*batch_size, *self.shape[1:]), **kwargs)
 
     def packb(self) -> bytes:
         obj_dict = asdict(self)
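
make_empty is renamed to make_zeros and now returns zero-initialized tensors. A brief sketch of the renamed method on both descriptor types; the tensor sizes are illustrative.

```python
# Brief sketch of the renamed method; sizes are illustrative.
import torch

from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor

full = TensorDescriptor.from_tensor(torch.randn(3, 5))
print(full.make_zeros().shape)        # torch.Size([3, 5]), filled with zeros

batched = BatchTensorDescriptor(5)    # 0-th (batch) dimension left unspecified
print(batched.make_zeros(8).shape)    # torch.Size([8, 5])
```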

+ 31 - 34
tests/test_moe.py

@@ -3,9 +3,12 @@ import numpy as np
 import pytest
 import torch
 
-import hivemind
-from hivemind.moe.client.expert import DUMMY
-from hivemind.moe.server import background_server, declare_experts, layers
+from hivemind.dht import DHT
+from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
+from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
+from hivemind.moe.server.layers import name_to_block
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
 @pytest.mark.forked
@@ -16,11 +19,9 @@ def test_moe():
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
     ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn."
-        )
+        dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
         for i in range(3):
             out = dmoe(torch.randn(10, 16))
@@ -35,9 +36,9 @@ def test_no_experts():
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
     ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
-        dmoe = hivemind.RemoteSwitchMixtureOfExperts(
+        dmoe = RemoteSwitchMixtureOfExperts(
             in_features=16,
             grid_size=(4, 4, 4),
             dht=dht,
@@ -74,10 +75,10 @@ def test_call_many(hidden_dim=16):
     ) as (server_endpoint, _):
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = hivemind.RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
+        e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
+        e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
 
-        mask, expert_outputs = hivemind.moe.client.moe._RemoteCallMany.apply(
+        mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
             [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
             k_min,
@@ -130,8 +131,8 @@ def test_remote_module_call(hidden_dim=16):
         optim_cls=None,
         no_dht=True,
     ) as (server_endpoint, _):
-        real_expert = hivemind.RemoteExpert("expert.0", server_endpoint)
-        fake_expert = hivemind.RemoteExpert("oiasfjiasjf", server_endpoint)
+        real_expert = RemoteExpert("expert.0", server_endpoint)
+        fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
 
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
@@ -152,12 +153,10 @@ def test_remote_module_call(hidden_dim=16):
 @pytest.mark.forked
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
-    dht = hivemind.DHT(start=True)
+    dht = DHT(start=True)
     assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
 
-    dmoe = hivemind.RemoteMixtureOfExperts(
-        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn."
-    )
+    dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
 
     for i in range(25):
         input = torch.randn(32)
@@ -174,7 +173,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[hivemind.RemoteExpert(uid, "") for uid in all_expert_uids]],
+            [[RemoteExpert(uid, "") for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
@@ -197,7 +196,7 @@ def test_determinism(hidden_dim=16):
         optim_cls=None,
         no_dht=True,
     ) as (server_endpoint, _):
-        expert = hivemind.RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
+        expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -212,8 +211,8 @@ def test_determinism(hidden_dim=16):
 @pytest.mark.forked
 def test_compute_expert_scores():
     try:
-        dht = hivemind.DHT(start=True)
-        moe = hivemind.moe.RemoteMixtureOfExperts(
+        dht = DHT(start=True)
+        moe = RemoteMixtureOfExperts(
             dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
         )
         gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
@@ -221,13 +220,11 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                hivemind.RemoteExpert(
-                    uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337"
-                )
+                RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))
-        ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
+        ]  # note: these experts do not exist on server, we use them only to test compute_expert_scores
         logits = moe.compute_expert_scores([gx, gy], batch_experts)
         torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
         assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
@@ -247,25 +244,25 @@ def test_client_anomaly_detection():
 
     experts = {}
     for i in range(4):
-        expert = layers.name_to_block["ffn"](HID_DIM)
-        experts[f"expert.{i}"] = hivemind.ExpertBackend(
+        expert = name_to_block["ffn"](HID_DIM)
+        experts[f"expert.{i}"] = ExpertBackend(
             name=f"expert.{i}",
             expert=expert,
             optimizer=torch.optim.Adam(expert.parameters()),
-            args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
-            outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
+            args_schema=(BatchTensorDescriptor(HID_DIM),),
+            outputs_schema=BatchTensorDescriptor(HID_DIM),
             max_batch_size=16,
         )
 
     experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
 
-    dht = hivemind.DHT(start=True)
-    server = hivemind.moe.Server(dht, experts, num_connection_handlers=1)
+    dht = DHT(start=True)
+    server = Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:
         server.ready.wait()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
+        dmoe = RemoteMixtureOfExperts(
             in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
 
@@ -282,7 +279,7 @@ def test_client_anomaly_detection():
         with pytest.raises(ValueError):
             inf_loss.backward()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
+        dmoe = RemoteMixtureOfExperts(
             in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         output = dmoe(input)