
Simplify ExpertBackend interface (#483)

- extract gradient clipping from ExpertBackend: this behavior can be achieved with a user-defined Optimizer (see ClippingWrapper in the new hivemind/moe/server/layers/optim.py below)
- remove stats from ExpertBackend: this behavior can be achieved with a user-defined Scheduler
- rename full_state -> state_dict, rationale: there is no "non-full" state in this context
- rename ExpertBackend.expert -> ExpertBackend.module to avoid confusion (see the usage sketch below)
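
A minimal sketch of the new interface; the module, dimensions and file name are placeholders rather than part of this commit. The optimizer and a ready-made LR scheduler are passed in directly, and checkpoints come from the renamed state_dict method:

    import torch
    from hivemind import BatchTensorDescriptor
    from hivemind.moe.server import ModuleBackend

    module = torch.nn.Linear(64, 64)  # placeholder expert module
    opt = torch.optim.Adam(module.parameters())

    backend = ModuleBackend(
        name="expert.0",
        module=module,                 # was: expert=...
        optimizer=opt,                 # optional now; gradient clipping lives inside the optimizer
        scheduler=torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0),  # any ready-made scheduler instance
        args_schema=(BatchTensorDescriptor(64),),
        outputs_schema=BatchTensorDescriptor(64),
        max_batch_size=16,
    )
    torch.save(backend.state_dict(), "expert.0.pt")  # was: backend.get_full_state()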

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 3 years ago
commit 5ea21a75f4

+ 5 - 5
benchmarks/benchmark_throughput.py

@@ -10,7 +10,7 @@ from hivemind.dht import DHT
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.moe.server import ExpertBackend, Server
+from hivemind.moe.server import ModuleBackend, Server
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p import P2P
 from hivemind.utils.limits import increase_file_limit
@@ -118,12 +118,12 @@ def benchmark_throughput(
         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
 
 
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        experts = {}
+        module_backends = {}
         for i in range(num_experts):
             expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
-            experts[f"expert.{i}"] = ExpertBackend(
+            module_backends[f"expert.{i}"] = ModuleBackend(
                 name=f"expert.{i}",
                 name=f"expert.{i}",
-                expert=expert,
+                module=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
                 args_schema=(BatchTensorDescriptor(hid_dim),),
                 outputs_schema=BatchTensorDescriptor(hid_dim),
@@ -133,7 +133,7 @@ def benchmark_throughput(

         server = Server(
             dht=server_dht,
-            expert_backends=experts,
+            module_backends=module_backends,
             num_connection_handlers=num_handlers,
             device=device,
         )

+ 5 - 5
docs/modules/server.rst

@@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo
 The hivemind.moe.server module is organized as follows:

 - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute.
-- ExpertBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
+- ModuleBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
   that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests.
-- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
+- Runtime_ balances the device (GPU) usage between several ModuleBackend_ instances that each service one expert.
 - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \
   and offers those batches to Runtime_ for processing.

@@ -25,9 +25,9 @@ The hivemind.moe.server module is organized as follows:
    :members:
    :member-order: bysource

-.. _ExpertBackend:
-.. autoclass:: ExpertBackend
-    :members: forward, backward, apply_gradients, get_info, get_pools
+.. _ModuleBackend:
+.. autoclass:: ModuleBackend
+    :members: forward, backward, on_backward, get_info, get_pools
     :member-order: bysource

 .. currentmodule:: hivemind.moe.server.runtime

+ 1 - 1
hivemind/__init__.py

@@ -2,7 +2,7 @@ from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
-    ExpertBackend,
+    ModuleBackend,
     RemoteExpert,
     RemoteMixtureOfExperts,
     RemoteSwitchMixtureOfExperts,

+ 2 - 1
hivemind/hivemind_cli/run_server.py

@@ -54,7 +54,8 @@ def main():
                         help='Server will report experts to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
-    parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
+    parser.add_argument('--num_training_steps', type=int, required=False, help='The total number of steps for LR schedule')
+
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')

     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],

+ 1 - 1
hivemind/moe/__init__.py

@@ -1,6 +1,6 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import (
-    ExpertBackend,
+    ModuleBackend,
     Server,
     background_server,
     declare_experts,

+ 1 - 1
hivemind/moe/server/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.moe.server.dht_handler import declare_experts, get_experts
-from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.layers import register_expert_class
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.server import Server, background_server

+ 9 - 9
hivemind/moe/server/checkpoints.py

@@ -8,7 +8,7 @@ from typing import Dict

 import torch

-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils.logging import get_logger

 logger = get_logger(__name__)
@@ -34,23 +34,23 @@ def copy_tree(src: str, dst: str):


 class CheckpointSaver(threading.Thread):
-    def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: float):
+    def __init__(self, module_backends: Dict[str, ModuleBackend], checkpoint_dir: Path, update_period: float):
         super().__init__()
         assert is_directory(checkpoint_dir)
-        self.expert_backends = expert_backends
+        self.module_backends = module_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
         self.stop = threading.Event()

         # create expert directories to ensure that the directory is writable and checkpoints can be loaded
-        store_experts(self.expert_backends, self.checkpoint_dir)
+        store_experts(self.module_backends, self.checkpoint_dir)

     def run(self) -> None:
         while not self.stop.wait(self.update_period):
-            store_experts(self.expert_backends, self.checkpoint_dir)
+            store_experts(self.module_backends, self.checkpoint_dir)


-def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+def store_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path):
     logger.debug(f"Storing experts at {checkpoint_dir.absolute()}")
     logger.debug(f"Storing experts at {checkpoint_dir.absolute()}")
     assert is_directory(checkpoint_dir)
     assert is_directory(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep="_")
     timestamp = datetime.now().isoformat(sep="_")
@@ -59,17 +59,17 @@ def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
             checkpoint_name = expert_dir / f"checkpoint_{timestamp}.pt"
-            torch.save(expert_backend.get_full_state(), checkpoint_name)
+            torch.save(expert_backend.state_dict(), checkpoint_name)
             os.symlink(checkpoint_name, expert_dir / "checkpoint_last.pt")
         copy_tree(tmpdirname, str(checkpoint_dir))


-def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+def load_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path):
     assert is_directory(checkpoint_dir)
     for expert_name, expert in experts.items():
         checkpoints_folder = checkpoint_dir / expert_name
         latest_checkpoint = checkpoints_folder / "checkpoint_last.pt"
         if latest_checkpoint.exists():
-            expert.load_full_state(torch.load(latest_checkpoint))
+            expert.load_state_dict(torch.load(latest_checkpoint))
         else:
             logger.warning(f"Failed to load checkpoint for expert {expert_name}")

+ 3 - 3
hivemind/moe/server/connection_handler.py

@@ -6,7 +6,7 @@ import torch

 from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE, P2P
@@ -25,10 +25,10 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):

     :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port
     :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
-    :param experts: a dict [UID -> ExpertBackend] with all active experts
+    :param experts: a dict [UID -> ModuleBackend] with all active experts
     """
     """
 
 
-    def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]):
+    def __init__(self, dht: DHT, experts: Dict[str, ModuleBackend]):
         super().__init__()
         self.dht, self.experts = dht, experts
         self._p2p: Optional[P2P] = None

+ 6 - 4
hivemind/moe/server/dht_handler.py

@@ -20,20 +20,22 @@ from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, MPFuture, get_dht_t


 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs):
+    def __init__(
+        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
+    ):
         super().__init__(**kwargs)
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
-        self.experts = experts
+        self.module_backends = module_backends
         self.dht = dht
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()

     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration)
+        declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration)
+            declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)


 def declare_experts(

+ 1 - 1
hivemind/moe/server/layers/dropout.py

@@ -19,7 +19,7 @@ class DeterministicDropoutFunction(torch.autograd.Function):
 class DeterministicDropout(nn.Module):
     """
     Custom dropout layer which accepts dropout mask as an input (drop_prob is only used for scaling input activations).
-    Can be used with RemoteExpert/ExpertBackend to ensure that dropout mask is the same at forward and backward steps
+    Can be used with RemoteExpert/ModuleBackend to ensure that dropout mask is the same at forward and backward steps
     """
     """
 
 
     def __init__(self, drop_prob):
     def __init__(self, drop_prob):

+ 58 - 0
hivemind/moe/server/layers/optim.py

@@ -0,0 +1,58 @@
+import torch
+
+
+class OptimizerWrapper(torch.optim.Optimizer):
+    """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""
+
+    def __init__(self, optim: torch.optim.Optimizer):
+        super().__init__(optim.param_groups, optim.defaults)
+        self.optim = optim
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self.optim)})"
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        return self.optim.load_state_dict(state_dict)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        return self.optim.zero_grad(*args, **kwargs)
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        return self.optim.add_param_group(param_group)
+
+
+class ClippingWrapper(OptimizerWrapper):
+    """A wrapper of torch.Optimizer that clips gradients by global norm before each step"""
+
+    def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
+        super().__init__(optim)
+        self.clip_grad_norm = clip_grad_norm
+
+    def step(self, *args, **kwargs):
+        parameters = tuple(param for group in self.param_groups for param in group["params"])
+        torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
+        return super().step(*args, **kwargs)
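
These wrappers replace the clip_grad_norm argument that ExpertBackend used to accept: as the server.py diff further below shows, Server.create now builds the expert optimizer first and, when clip_grad_norm is given, wraps it in ClippingWrapper, so clipping happens inside optimizer.step() rather than inside the backend.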

+ 46 - 84
hivemind/moe/server/expert_backend.py → hivemind/moe/server/module_backend.py

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union

 import torch
 from torch import nn
@@ -8,19 +8,20 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor

+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 logger = get_logger(__name__)


-class ExpertBackend:
+class ModuleBackend:
     """
     """
-    ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
-    By default, ExpertBackend handles three types of requests:
+    ModuleBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
+    By default, ModuleBackend handles three types of requests:

      - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
      - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
      - get_info - return expert metadata. Not batched.

-    :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
+    :param module: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:

      - Experts must always receive the same set of args and kwargs and produce output tensors of same type
      - All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
@@ -34,49 +35,37 @@ class ExpertBackend:
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
     :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
     :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
-    :param num_warmup_steps: the number of warmup steps for LR schedule
-    :param num_total_steps: the total number of steps for LR schedule
-    :param clip_grad_norm: maximum gradient norm used for clipping
     :param kwargs: extra parameters to be forwarded into TaskPool.__init__
     """

     def __init__(
         self,
         name: str,
-        expert: nn.Module,
-        optimizer: torch.optim.Optimizer,
+        module: nn.Module,
         *,
-        scheduler: Callable = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        scheduler: Optional[LRSchedulerBase] = None,
         args_schema: Tuple[BatchTensorDescriptor, ...] = None,
         kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
         outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
-        num_warmup_steps: int = None,
-        num_total_steps: int = None,
-        clip_grad_norm: float = None,
         **kwargs,
     ):
         super().__init__()
-        self.expert, self.optimizer, self.name = expert, optimizer, name
-
-        if scheduler is None:
-            self.scheduler = None
-        else:
-            assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None
-            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps)
-        self.clip_grad_norm = clip_grad_norm
+        self.name, self.module, self.optimizer, self.scheduler = name, module, optimizer, scheduler

         self.args_schema = args_schema = tuple(args_schema or ())
         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
         assert args_schema or kwargs_schema, (
-            "expert must receive at least one positional or keyword input."
+            f"Module must take at least one positional or keyword input."
             " Did you forget to provide args_schema/kwargs_schema?"
             " Did you forget to provide args_schema/kwargs_schema?"
         )
         )
+        assert optimizer is not None or scheduler is None, "scheduler should only be used if optimizer is not None"

         if outputs_schema is None:
             # run expert once to get outputs schema
             dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
             dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
-            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
+            dummy_outputs = self.module(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)

         self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
@@ -87,22 +76,17 @@ class ExpertBackend:
         self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
         self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)

-        self.update_count = 0
-        self.examples_processed = 0
-
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
-        To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.
+        To submit a request for asynchronous processing, please use ``ModuleBackend.forward_pool.submit_task``.
+
+        .. warning: if the underlying module performs non-gradient updates (e.g. batchnorm), it will be updated twice:
+           once during forward pass, and again during backward. This behavior is similar to gradient checkpointing.

         Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;
-
            It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;
-
-           .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
-           .. For now, either register all buffers as outputs or avoid stateful experts
-
         """
         """
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
 
 
@@ -110,7 +94,7 @@ class ExpertBackend:
             raise RuntimeError("Batch should contain more than 0 samples")
             raise RuntimeError("Batch should contain more than 0 samples")
 
 
         with torch.no_grad():
         with torch.no_grad():
-            outputs = self.expert(*args, **kwargs)
+            outputs = self.module(*args, **kwargs)

         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))
@@ -118,7 +102,7 @@ class ExpertBackend:
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually
-        To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.
+        To submit a request for asynchronous processing, please use ``ModuleBackend.backward_pool.submit_task``.

         Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;
@@ -128,9 +112,7 @@ class ExpertBackend:
            Runtime doesn't guarantee that backward will be performed in the same order and for the same data
            as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.

-           .. todo correct state handling (see forward)
-
-           Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
+           Please make sure to call ``ModuleBackend.on_backward`` after each call to backward
         """
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
 
@@ -148,7 +130,7 @@ class ExpertBackend:
 
 
             batch_size = args[0].size(0)

-            outputs = self.expert(*args, **kwargs)
+            outputs = self.module(*args, **kwargs)
             assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"

             outputs_flat = tuple(nested_flatten(outputs))
@@ -163,65 +145,45 @@ class ExpertBackend:
             torch.autograd.backward(
                 outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
             )
-            self.apply_gradients(batch_size)
+            self.on_backward(batch_size)

         return tuple(
             x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs))
         )

-    def apply_gradients(self, batch_size) -> None:
+    def on_backward(self, batch_size: int) -> None:
         """
         """
-        Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
+        Train the expert for one step. This method is called by ``ModuleBackend.backward`` after computing gradients.
         """
         """
-        if self.clip_grad_norm is not None:
-            torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
+        if self.optimizer is not None:
+            self.optimizer.step()
+            self.optimizer.zero_grad()

         if self.scheduler is not None:
             self.scheduler.step()

-        self.update_count += 1
-        self.examples_processed += batch_size
-
-    def get_stats(self) -> Dict:
-        """
-        Return current expert training statistics (number of updates, number of processed examples after
-        last optimizer step)
-        """
-        return {"updates": self.update_count, "examples_processed": self.examples_processed}
-
-    def get_full_state(self) -> Dict:
-        """
-        Return the current state of the expert (including batch processing statistics)
-        """
-        full_state = {
-            "stats": self.get_stats(),
-            "model": self.expert.state_dict(),
-            "optimizer": self.optimizer.state_dict(),
-            "scheduler": {} if self.scheduler is None else self.scheduler.state_dict(),
-        }
+    def state_dict(self) -> Dict:
+        """Return the current state of the module, optimizer, and scheduler"""
+        full_state = dict(module=self.module.state_dict())
+        if self.optimizer is not None:
+            full_state["optimizer"] = self.optimizer.state_dict()
+        if self.scheduler is not None:
+            full_state["scheduler"] = self.scheduler.state_dict()
         return full_state

-    def load_full_state(self, state_dict: Dict):
-        if "stats" in state_dict:
-            self.update_count = state_dict["stats"]["updates"]
-            self.examples_processed = state_dict["stats"]["examples_processed"]
-        else:
-            logger.warning(f"Batch processing stats missing for expert {self.name}")
-
-        self.expert.load_state_dict(state_dict["model"])
+    def load_state_dict(self, state_dict: Dict):
+        self.module.load_state_dict(state_dict["module"])
+        if self.optimizer is not None:
+            if "optimizer" in state_dict:
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                logger.warning(f"Optimizer state missing for {self.name}")
 
 
-        if "optimizer" in state_dict:
-            self.optimizer.load_state_dict(state_dict["optimizer"])
-        else:
-            logger.warning(f"Optimizer state missing for expert {self.name}")
-
-        if self.scheduler is not None and "scheduler" in state_dict:
-            self.scheduler.load_state_dict(state_dict["scheduler"])
-        else:
-            logger.warning(f"Learning rate scheduler state missing for expert {self.name}")
+        if self.scheduler is not None:
+            if "scheduler" in state_dict:
+                self.scheduler.load_state_dict(state_dict["scheduler"])
+            else:
+                logger.warning(f"Learning rate scheduler state missing for {self.name}")
 
 
     def get_info(self) -> Dict[str, Any]:
     def get_info(self) -> Dict[str, Any]:
         """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
         """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""

+ 11 - 11
hivemind/moe/server/runtime.py

@@ -12,7 +12,7 @@ from typing import Dict, NamedTuple, Optional
 import torch
 from prefetch_generator import BackgroundGenerator

-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger

 logger = get_logger(__name__)
@@ -20,20 +20,20 @@ logger = get_logger(__name__)

 class Runtime(threading.Thread):
     """
-    A group of processes that processes incoming requests for multiple experts on a shared device.
+    A group of processes that processes incoming requests for multiple module backends on a shared device.
     Runtime is usually created and managed by Server, humans need not apply.

     For debugging, you can start runtime manually with .start() or .run()

-    >>> expert_backends = {'expert_name': ExpertBackend(**kwargs)}
-    >>> runtime = Runtime(expert_backends)
+    >>> module_backends = {'expert_name': ModuleBackend(**kwargs)}
+    >>> runtime = Runtime(module_backends)
     >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
     >>> runtime.ready.wait()  # await for runtime to load all experts on device and create request pools
-    >>> future = runtime.expert_backends['expert_name'].forward_pool.submit_task(*expert_inputs)
+    >>> future = runtime.module_backends['expert_name'].forward_pool.submit_task(*module_inputs)
     >>> print("Returned:", future.result())
     >>> print("Returned:", future.result())
     >>> runtime.shutdown()
     >>> runtime.shutdown()
 
 
-    :param expert_backends: a dict [expert uid -> ExpertBackend]
+    :param module_backends: a dict [expert uid -> ModuleBackend]
     :param prefetch_batches: form up to this many batches in advance
     :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
     :param device: if specified, moves all experts and data to this device via .to(device=device).
@@ -46,15 +46,15 @@ class Runtime(threading.Thread):

     def __init__(
         self,
-        expert_backends: Dict[str, ExpertBackend],
+        module_backends: Dict[str, ModuleBackend],
         prefetch_batches=64,
         sender_threads: int = 1,
         device: torch.device = None,
         stats_report_interval: Optional[int] = None,
     ):
         super().__init__()
-        self.expert_backends = expert_backends
-        self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
+        self.module_backends = module_backends
+        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
         self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
         self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
         self.shutdown_trigger = mp.Event()
@@ -69,8 +69,8 @@ class Runtime(threading.Thread):
             if not pool.is_alive():
                 pool.start()
         if self.device is not None:
-            for expert_backend in self.expert_backends.values():
-                expert_backend.expert.to(self.device)
+            for backend in self.module_backends.values():
+                backend.module.to(self.device)

         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:

+ 33 - 28
hivemind/moe/server/server.py

@@ -15,13 +15,14 @@ from hivemind.moe.expert_uid import UID_DELIMITER
 from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
-from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.layers import (
     add_custom_models_from_file,
     name_to_block,
     name_to_input,
     schedule_name_to_scheduler,
 )
+from hivemind.moe.server.layers.optim import ClippingWrapper
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.runtime import Runtime
 from hivemind.p2p import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
@@ -33,7 +34,7 @@ logger = get_logger(__name__)

 class Server(threading.Thread):
     """
-    Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts.
+    Server allows you to host "experts" - pytorch subnetworks that can be accessed remotely by peers.
     After creation, a server should be started: see Server.run or Server.run_in_background.

     A working server does two things:
@@ -41,7 +42,7 @@ class Server(threading.Thread):
      - publishes updates to expert status every :update_period: seconds

     :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions.
-    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
+    :param module_backends: dict{expert uid (str) : ModuleBackend} for all expert hosted by this server.
     :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
         if too small for normal functioning, we recommend 4 handlers per expert backend.
     :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
@@ -54,7 +55,7 @@ class Server(threading.Thread):
     def __init__(
         self,
         dht: DHT,
-        expert_backends: Dict[str, ExpertBackend],
+        module_backends: Dict[str, ModuleBackend],
         num_connection_handlers: int = 1,
         update_period: float = 30,
         expiration: Optional[float] = None,
@@ -63,18 +64,18 @@ class Server(threading.Thread):
         **kwargs,
     ):
         super().__init__()
-        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
+        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period

-        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [ConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:
-            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
+            self.checkpoint_saver = CheckpointSaver(module_backends, checkpoint_dir, update_period)
         else:
             self.checkpoint_saver = None
-        self.runtime = Runtime(self.experts, **kwargs)
+        self.runtime = Runtime(self.module_backends, **kwargs)

-        if self.experts:
+        if self.module_backends:
             self.dht_handler_thread = DHTHandlerThread(
-                experts=self.experts,
+                module_backends=self.module_backends,
                 dht=self.dht,
                 update_period=self.update_period,
                 expiration=expiration,
@@ -95,7 +96,7 @@ class Server(threading.Thread):
         optim_cls=torch.optim.Adam,
         scheduler: str = "none",
         num_warmup_steps=None,
-        num_total_steps=None,
+        num_training_steps=None,
         clip_grad_norm=None,
         num_handlers=None,
         min_batch_size=1,
@@ -113,7 +114,7 @@ class Server(threading.Thread):
         **kwargs,
     ) -> Server:
         """
-        Instantiate a server with several identical experts. See argparse comments below for details
+        Instantiate a server with several identical modules. See argparse comments below for details

         :param num_experts: run this many identical experts
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
@@ -129,7 +130,7 @@ class Server(threading.Thread):
         :param optim_cls: uses this optimizer to train all experts
         :param scheduler: if not `none`, the name of the expert LR scheduler
         :param num_warmup_steps: the number of warmup steps for LR schedule
-        :param num_total_steps: the total number of steps for LR schedule
+        :param num_training_steps: the total number of steps for LR schedule
         :param clip_grad_norm: maximum gradient norm used for clipping

         :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
@@ -138,7 +139,7 @@ class Server(threading.Thread):
 
 
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
-            for each BatchTensorProto in ExpertBackend for the respective experts.
+            for each BatchTensorProto in ModuleBackend for the respective experts.

         :param start: if True, starts server right away and returns when server is ready for requests
         :param stats_report_interval: interval between two reports of batch processing performance statistics
@@ -180,7 +181,6 @@ class Server(threading.Thread):
 
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
 
         sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim)
         sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim)
@@ -189,21 +189,26 @@ class Server(threading.Thread):
         else:
             args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)

-        scheduler = schedule_name_to_scheduler[scheduler]
+        scheduler_cls = schedule_name_to_scheduler[scheduler]
+        if scheduler_cls is not None:
+            scheduler_cls = partial(
+                scheduler_cls, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
+            )

         # initialize experts
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
-            experts[expert_uid] = ExpertBackend(
+            optimizer = optim_cls(expert.parameters()) if optim_cls is not None else None
+            scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None
+            if clip_grad_norm is not None:
+                optimizer = ClippingWrapper(optimizer, clip_grad_norm)
+            experts[expert_uid] = ModuleBackend(
                 name=expert_uid,
-                expert=expert,
+                module=expert,
                 args_schema=args_schema,
-                optimizer=optim_cls(expert.parameters()),
+                optimizer=optimizer,
                 scheduler=scheduler,
-                num_warmup_steps=num_warmup_steps,
-                num_total_steps=num_total_steps,
-                clip_grad_norm=clip_grad_norm,
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )
@@ -228,15 +233,15 @@ class Server(threading.Thread):
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        logger.info(f"Server started with {len(self.experts)} experts:")
-        for expert_name, backend in self.experts.items():
-            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
+        logger.info(f"Server started with {len(self.module_backends)} modules:")
+        for expert_name, backend in self.module_backends.items():
+            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
 
 
         if not self.dht.is_alive():
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)
             self.dht.run_in_background(await_ready=True)
 
 
-        if self.experts:
+        if self.module_backends:
             self.dht_handler_thread.start()

         if self.checkpoint_saver is not None:
@@ -287,7 +292,7 @@ class Server(threading.Thread):
             process.join()
         logger.debug("Connection handlers terminated")

-        if self.experts:
+        if self.module_backends:
             self.dht_handler_thread.stop.set()
             self.dht_handler_thread.join()


+ 5 - 5
tests/test_connection_handler.py

@@ -10,7 +10,7 @@ import torch
 from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PHandlerError
 from hivemind.proto import runtime_pb2
@@ -25,7 +25,7 @@ from hivemind.utils.tensor_descr import BatchTensorDescriptor
 async def test_connection_handler_info():
     handler = ConnectionHandler(
         DHT(start=True),
-        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
     )
     handler.start()

@@ -48,7 +48,7 @@ async def test_connection_handler_info():
 async def test_connection_handler_forward():
     handler = ConnectionHandler(
         DHT(start=True),
-        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
     )
     handler.start()

@@ -109,7 +109,7 @@ async def test_connection_handler_forward():
 async def test_connection_handler_backward():
     handler = ConnectionHandler(
         DHT(start=True),
-        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
     )
     handler.start()

@@ -179,7 +179,7 @@ class DummyPool(TaskPool):
         return [inputs[0] * self.k]


-class DummyExpertBackend(ExpertBackend):
+class DummyModuleBackend(ModuleBackend):
     def __init__(self, name: str, k: float):
         self.name = name
         self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]

+ 10 - 8
tests/test_expert_backend.py

@@ -5,7 +5,7 @@ import pytest
 import torch
 from torch.nn import Linear

-from hivemind import BatchTensorDescriptor, ExpertBackend
+from hivemind import BatchTensorDescriptor, ModuleBackend
 from hivemind.moe.server.checkpoints import load_experts, store_experts
 from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup

@@ -22,13 +22,15 @@ def example_experts():
     opt = torch.optim.SGD(expert.parameters(), PEAK_LR)

     args_schema = (BatchTensorDescriptor(1),)
-    expert_backend = ExpertBackend(
+    expert_backend = ModuleBackend(
         name=EXPERT_NAME,
-        expert=expert,
+        module=expert,
         optimizer=opt,
-        scheduler=get_linear_schedule_with_warmup,
-        num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
-        num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+        scheduler=get_linear_schedule_with_warmup(
+            opt,
+            num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
+            num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+        ),
         args_schema=args_schema,
         outputs_schema=BatchTensorDescriptor(1),
         max_batch_size=1,
@@ -39,7 +41,7 @@ def example_experts():

 @pytest.mark.forked
 def test_save_load_checkpoints(example_experts):
-    expert = example_experts[EXPERT_NAME].expert
+    expert = example_experts[EXPERT_NAME].module

     with TemporaryDirectory() as tmpdir:
         tmp_path = Path(tmpdir)
@@ -79,7 +81,7 @@ def test_restore_update_count(example_experts):
             expert_backend.backward(batch, loss_grad)

         load_experts(example_experts, tmp_path)
-        assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE
+        assert expert_backend.scheduler._step_count == BACKWARD_PASSES_BEFORE_SAVE + 1


 @pytest.mark.forked

+ 4 - 4
tests/test_moe.py

@@ -7,7 +7,7 @@ from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
 from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
+from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import BatchTensorDescriptor, get_dht_time
@@ -257,16 +257,16 @@ def test_client_anomaly_detection():
     experts = {}
     for i in range(4):
         expert = name_to_block["ffn"](HID_DIM)
-        experts[f"expert.{i}"] = ExpertBackend(
+        experts[f"expert.{i}"] = ModuleBackend(
             name=f"expert.{i}",
             name=f"expert.{i}",
-            expert=expert,
+            module=expert,
             optimizer=torch.optim.Adam(expert.parameters()),
             args_schema=(BatchTensorDescriptor(HID_DIM),),
             outputs_schema=BatchTensorDescriptor(HID_DIM),
             max_batch_size=16,
         )

-    experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
+    experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")
 
 
     dht = DHT(start=True)
     dht = DHT(start=True)
     server = Server(dht, experts, num_connection_handlers=1)
     server = Server(dht, experts, num_connection_handlers=1)