Bladeren bron

Simplify ExpertBackend interface (#483)

- extract gradient clipping from ExpertBackend: this behavior can be achieved with a user-defined Optimizer
- remove stats from ExpertBackend: this behavior can be achieved with a user-defined Scheduler
- rename full_state -> state_dict, rationale: there is no "non-full" state in this context
- rename ExpertBackend.expert -> ExpertBackend.module to avoid confusion

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 3 jaren geleden
bovenliggende
commit
5ea21a75f4

+ 5 - 5
benchmarks/benchmark_throughput.py

@@ -10,7 +10,7 @@ from hivemind.dht import DHT
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.moe.server import ExpertBackend, Server
+from hivemind.moe.server import ModuleBackend, Server
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p import P2P
 from hivemind.utils.limits import increase_file_limit
@@ -118,12 +118,12 @@ def benchmark_throughput(
         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
 
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        experts = {}
+        module_backends = {}
         for i in range(num_experts):
             expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
-            experts[f"expert.{i}"] = ExpertBackend(
+            module_backends[f"expert.{i}"] = ModuleBackend(
                 name=f"expert.{i}",
-                expert=expert,
+                module=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
                 args_schema=(BatchTensorDescriptor(hid_dim),),
                 outputs_schema=BatchTensorDescriptor(hid_dim),
@@ -133,7 +133,7 @@ def benchmark_throughput(
 
         server = Server(
             dht=server_dht,
-            expert_backends=experts,
+            module_backends=module_backends,
             num_connection_handlers=num_handlers,
             device=device,
         )

+ 5 - 5
docs/modules/server.rst

@@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo
 The hivemind.moe.server module is organized as follows:
 
 - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute.
-- ExpertBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
+- ModuleBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
   that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests.
-- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
+- Runtime_ balances the device (GPU) usage between several ModuleBackend_ instances that each service one expert.
 - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \
   and offers those batches to Runtime_ for processing.
 
@@ -25,9 +25,9 @@ The hivemind.moe.server module is organized as follows:
    :members:
    :member-order: bysource
 
-.. _ExpertBackend:
-.. autoclass:: ExpertBackend
-    :members: forward, backward, apply_gradients, get_info, get_pools
+.. _ModuleBackend:
+.. autoclass:: ModuleBackend
+    :members: forward, backward, on_backward, get_info, get_pools
     :member-order: bysource
 
 .. currentmodule:: hivemind.moe.server.runtime

+ 1 - 1
hivemind/__init__.py

@@ -2,7 +2,7 @@ from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
-    ExpertBackend,
+    ModuleBackend,
     RemoteExpert,
     RemoteMixtureOfExperts,
     RemoteSwitchMixtureOfExperts,

+ 2 - 1
hivemind/hivemind_cli/run_server.py

@@ -54,7 +54,8 @@ def main():
                         help='Server will report experts to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
-    parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
+    parser.add_argument('--num_training_steps', type=int, required=False, help='The total number of steps for LR schedule')
+
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')
 
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],

+ 1 - 1
hivemind/moe/__init__.py

@@ -1,6 +1,6 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import (
-    ExpertBackend,
+    ModuleBackend,
     Server,
     background_server,
     declare_experts,

+ 1 - 1
hivemind/moe/server/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.moe.server.dht_handler import declare_experts, get_experts
-from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.layers import register_expert_class
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.server import Server, background_server

+ 9 - 9
hivemind/moe/server/checkpoints.py

@@ -8,7 +8,7 @@ from typing import Dict
 
 import torch
 
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -34,23 +34,23 @@ def copy_tree(src: str, dst: str):
 
 
 class CheckpointSaver(threading.Thread):
-    def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: float):
+    def __init__(self, module_backends: Dict[str, ModuleBackend], checkpoint_dir: Path, update_period: float):
         super().__init__()
         assert is_directory(checkpoint_dir)
-        self.expert_backends = expert_backends
+        self.module_backends = module_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
         self.stop = threading.Event()
 
         # create expert directories to ensure that the directory is writable and checkpoints can be loaded
-        store_experts(self.expert_backends, self.checkpoint_dir)
+        store_experts(self.module_backends, self.checkpoint_dir)
 
     def run(self) -> None:
         while not self.stop.wait(self.update_period):
-            store_experts(self.expert_backends, self.checkpoint_dir)
+            store_experts(self.module_backends, self.checkpoint_dir)
 
 
-def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+def store_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path):
     logger.debug(f"Storing experts at {checkpoint_dir.absolute()}")
     assert is_directory(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep="_")
@@ -59,17 +59,17 @@ def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
             checkpoint_name = expert_dir / f"checkpoint_{timestamp}.pt"
-            torch.save(expert_backend.get_full_state(), checkpoint_name)
+            torch.save(expert_backend.state_dict(), checkpoint_name)
             os.symlink(checkpoint_name, expert_dir / "checkpoint_last.pt")
         copy_tree(tmpdirname, str(checkpoint_dir))
 
 
-def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+def load_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path):
     assert is_directory(checkpoint_dir)
     for expert_name, expert in experts.items():
         checkpoints_folder = checkpoint_dir / expert_name
         latest_checkpoint = checkpoints_folder / "checkpoint_last.pt"
         if latest_checkpoint.exists():
-            expert.load_full_state(torch.load(latest_checkpoint))
+            expert.load_state_dict(torch.load(latest_checkpoint))
         else:
             logger.warning(f"Failed to load checkpoint for expert {expert_name}")

+ 3 - 3
hivemind/moe/server/connection_handler.py

@@ -6,7 +6,7 @@ import torch
 
 from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE, P2P
@@ -25,10 +25,10 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
 
     :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port
     :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
-    :param experts: a dict [UID -> ExpertBackend] with all active experts
+    :param experts: a dict [UID -> ModuleBackend] with all active experts
     """
 
-    def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]):
+    def __init__(self, dht: DHT, experts: Dict[str, ModuleBackend]):
         super().__init__()
         self.dht, self.experts = dht, experts
         self._p2p: Optional[P2P] = None

+ 6 - 4
hivemind/moe/server/dht_handler.py

@@ -20,20 +20,22 @@ from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, MPFuture, get_dht_t
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs):
+    def __init__(
+        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
+    ):
         super().__init__(**kwargs)
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
-        self.experts = experts
+        self.module_backends = module_backends
         self.dht = dht
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration)
+        declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration)
+            declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)
 
 
 def declare_experts(

+ 1 - 1
hivemind/moe/server/layers/dropout.py

@@ -19,7 +19,7 @@ class DeterministicDropoutFunction(torch.autograd.Function):
 class DeterministicDropout(nn.Module):
     """
     Custom dropout layer which accepts dropout mask as an input (drop_prob is only used for scaling input activations).
-    Can be used with RemoteExpert/ExpertBackend to ensure that dropout mask is the same at forward and backward steps
+    Can be used with RemoteExpert/ModuleBackend to ensure that dropout mask is the same at forward and backward steps
     """
 
     def __init__(self, drop_prob):

+ 58 - 0
hivemind/moe/server/layers/optim.py

@@ -0,0 +1,58 @@
+import torch
+
+
+class OptimizerWrapper(torch.optim.Optimizer):
+    """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""
+
+    def __init__(self, optim: torch.optim.Optimizer):
+        super().__init__(optim.param_groups, optim.defaults)
+        self.optim = optim
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self.optim)})"
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        return self.optim.load_state_dict(state_dict)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        return self.optim.zero_grad(*args, **kwargs)
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        return self.optim.add_param_group(param_group)
+
+
+class ClippingWrapper(OptimizerWrapper):
+    """A wrapper of torch.Optimizer that clips gradients by global norm before each step"""
+
+    def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
+        super().__init__(optim)
+        self.clip_grad_norm = clip_grad_norm
+
+    def step(self, *args, **kwargs):
+        parameters = tuple(param for group in self.param_groups for param in group["params"])
+        torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
+        return super().step(*args, **kwargs)

+ 46 - 84
hivemind/moe/server/expert_backend.py → hivemind/moe/server/module_backend.py

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
@@ -8,19 +8,20 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 logger = get_logger(__name__)
 
 
-class ExpertBackend:
+class ModuleBackend:
     """
-    ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
-    By default, ExpertBackend handles three types of requests:
+    ModuleBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
+    By default, ModuleBackend handles three types of requests:
 
      - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
      - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
      - get_info - return expert metadata. Not batched.
 
-    :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
+    :param module: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
 
      - Experts must always receive the same set of args and kwargs and produce output tensors of same type
      - All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
@@ -34,49 +35,37 @@ class ExpertBackend:
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
     :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
     :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
-    :param num_warmup_steps: the number of warmup steps for LR schedule
-    :param num_total_steps: the total number of steps for LR schedule
-    :param clip_grad_norm: maximum gradient norm used for clipping
     :param kwargs: extra parameters to be forwarded into TaskPool.__init__
     """
 
     def __init__(
         self,
         name: str,
-        expert: nn.Module,
-        optimizer: torch.optim.Optimizer,
+        module: nn.Module,
         *,
-        scheduler: Callable = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        scheduler: Optional[LRSchedulerBase] = None,
         args_schema: Tuple[BatchTensorDescriptor, ...] = None,
         kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
         outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
-        num_warmup_steps: int = None,
-        num_total_steps: int = None,
-        clip_grad_norm: float = None,
         **kwargs,
     ):
         super().__init__()
-        self.expert, self.optimizer, self.name = expert, optimizer, name
-
-        if scheduler is None:
-            self.scheduler = None
-        else:
-            assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None
-            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps)
-        self.clip_grad_norm = clip_grad_norm
+        self.name, self.module, self.optimizer, self.scheduler = name, module, optimizer, scheduler
 
         self.args_schema = args_schema = tuple(args_schema or ())
         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
         assert args_schema or kwargs_schema, (
-            "expert must receive at least one positional or keyword input."
+            f"Module must take at least one positional or keyword input."
             " Did you forget to provide args_schema/kwargs_schema?"
         )
+        assert optimizer is not None or scheduler is None, "scheduler should only be used if optimizer is not None"
 
         if outputs_schema is None:
             # run expert once to get outputs schema
             dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
             dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
-            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
+            dummy_outputs = self.module(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 
         self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
@@ -87,22 +76,17 @@ class ExpertBackend:
         self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
         self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
 
-        self.update_count = 0
-        self.examples_processed = 0
-
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
-        To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.
+        To submit a request for asynchronous processing, please use ``ModuleBackend.forward_pool.submit_task``.
+
+        .. warning: if the underlying module performs non-gradient updates (e.g. batchnorm), it will be updated twice:
+           once during forward pass, and again during backward. This behavior is similar to gradient checkpointing.
 
         Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;
-
            It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;
-
-           .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
-           .. For now, either register all buffers as outputs or avoid stateful experts
-
         """
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
 
@@ -110,7 +94,7 @@ class ExpertBackend:
             raise RuntimeError("Batch should contain more than 0 samples")
 
         with torch.no_grad():
-            outputs = self.expert(*args, **kwargs)
+            outputs = self.module(*args, **kwargs)
 
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))
@@ -118,7 +102,7 @@ class ExpertBackend:
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually
-        To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.
+        To submit a request for asynchronous processing, please use ``ModuleBackend.backward_pool.submit_task``.
 
         Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;
@@ -128,9 +112,7 @@ class ExpertBackend:
            Runtime doesn't guarantee that backward will be performed in the same order and for the same data
            as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.
 
-           .. todo correct state handling (see forward)
-
-           Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
+           Please make sure to call ``ModuleBackend.on_backward`` after each call to backward
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
@@ -148,7 +130,7 @@ class ExpertBackend:
 
             batch_size = args[0].size(0)
 
-            outputs = self.expert(*args, **kwargs)
+            outputs = self.module(*args, **kwargs)
             assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
 
             outputs_flat = tuple(nested_flatten(outputs))
@@ -163,65 +145,45 @@ class ExpertBackend:
             torch.autograd.backward(
                 outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
             )
-            self.apply_gradients(batch_size)
+            self.on_backward(batch_size)
 
         return tuple(
             x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs))
         )
 
-    def apply_gradients(self, batch_size) -> None:
+    def on_backward(self, batch_size: int) -> None:
         """
-        Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
+        Train the expert for one step. This method is called by ``ModuleBackend.backward`` after computing gradients.
         """
-        if self.clip_grad_norm is not None:
-            torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
+        if self.optimizer is not None:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
 
         if self.scheduler is not None:
             self.scheduler.step()
 
-        self.update_count += 1
-        self.examples_processed += batch_size
-
-    def get_stats(self) -> Dict:
-        """
-        Return current expert training statistics (number of updates, number of processed examples after
-        last optimizer step)
-        """
-        return {"updates": self.update_count, "examples_processed": self.examples_processed}
-
-    def get_full_state(self) -> Dict:
-        """
-        Return the current state of the expert (including batch processing statistics)
-        """
-        full_state = {
-            "stats": self.get_stats(),
-            "model": self.expert.state_dict(),
-            "optimizer": self.optimizer.state_dict(),
-            "scheduler": {} if self.scheduler is None else self.scheduler.state_dict(),
-        }
+    def state_dict(self) -> Dict:
+        """Return the current state of the module, optimizer, and scheduler"""
+        full_state = dict(module=self.module.state_dict())
+        if self.optimizer is not None:
+            full_state["optimizer"] = self.optimizer.state_dict()
+        if self.scheduler is not None:
+            full_state["scheduler"] = self.scheduler.state_dict()
         return full_state
 
-    def load_full_state(self, state_dict: Dict):
-        if "stats" in state_dict:
-            self.update_count = state_dict["stats"]["updates"]
-            self.examples_processed = state_dict["stats"]["examples_processed"]
-        else:
-            logger.warning(f"Batch processing stats missing for expert {self.name}")
-
-        self.expert.load_state_dict(state_dict["model"])
+    def load_state_dict(self, state_dict: Dict):
+        self.module.load_state_dict(state_dict["module"])
+        if self.optimizer is not None:
+            if "optimizer" in state_dict:
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                logger.warning(f"Optimizer state missing for {self.name}")
 
-        if "optimizer" in state_dict:
-            self.optimizer.load_state_dict(state_dict["optimizer"])
-        else:
-            logger.warning(f"Optimizer state missing for expert {self.name}")
-
-        if self.scheduler is not None and "scheduler" in state_dict:
-            self.scheduler.load_state_dict(state_dict["scheduler"])
-        else:
-            logger.warning(f"Learning rate scheduler state missing for expert {self.name}")
+        if self.scheduler is not None:
+            if "scheduler" in state_dict:
+                self.scheduler.load_state_dict(state_dict["scheduler"])
+            else:
+                logger.warning(f"Learning rate scheduler state missing for {self.name}")
 
     def get_info(self) -> Dict[str, Any]:
         """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""

+ 11 - 11
hivemind/moe/server/runtime.py

@@ -12,7 +12,7 @@ from typing import Dict, NamedTuple, Optional
 import torch
 from prefetch_generator import BackgroundGenerator
 
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 
 logger = get_logger(__name__)
@@ -20,20 +20,20 @@ logger = get_logger(__name__)
 
 class Runtime(threading.Thread):
     """
-    A group of processes that processes incoming requests for multiple experts on a shared device.
+    A group of processes that processes incoming requests for multiple module backends on a shared device.
     Runtime is usually created and managed by Server, humans need not apply.
 
     For debugging, you can start runtime manually with .start() or .run()
 
-    >>> expert_backends = {'expert_name': ExpertBackend(**kwargs)}
-    >>> runtime = Runtime(expert_backends)
+    >>> module_backends = {'expert_name': ModuleBackend(**kwargs)}
+    >>> runtime = Runtime(module_backends)
     >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
     >>> runtime.ready.wait()  # await for runtime to load all experts on device and create request pools
-    >>> future = runtime.expert_backends['expert_name'].forward_pool.submit_task(*expert_inputs)
+    >>> future = runtime.module_backends['expert_name'].forward_pool.submit_task(*module_inputs)
     >>> print("Returned:", future.result())
     >>> runtime.shutdown()
 
-    :param expert_backends: a dict [expert uid -> ExpertBackend]
+    :param module_backends: a dict [expert uid -> ModuleBackend]
     :param prefetch_batches: form up to this many batches in advance
     :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
     :param device: if specified, moves all experts and data to this device via .to(device=device).
@@ -46,15 +46,15 @@ class Runtime(threading.Thread):
 
     def __init__(
         self,
-        expert_backends: Dict[str, ExpertBackend],
+        module_backends: Dict[str, ModuleBackend],
         prefetch_batches=64,
         sender_threads: int = 1,
         device: torch.device = None,
         stats_report_interval: Optional[int] = None,
     ):
         super().__init__()
-        self.expert_backends = expert_backends
-        self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
+        self.module_backends = module_backends
+        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
         self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
         self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
         self.shutdown_trigger = mp.Event()
@@ -69,8 +69,8 @@ class Runtime(threading.Thread):
             if not pool.is_alive():
                 pool.start()
         if self.device is not None:
-            for expert_backend in self.expert_backends.values():
-                expert_backend.expert.to(self.device)
+            for backend in self.module_backends.values():
+                backend.module.to(self.device)
 
         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:

+ 33 - 28
hivemind/moe/server/server.py

@@ -15,13 +15,14 @@ from hivemind.moe.expert_uid import UID_DELIMITER
 from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
-from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.layers import (
     add_custom_models_from_file,
     name_to_block,
     name_to_input,
     schedule_name_to_scheduler,
 )
+from hivemind.moe.server.layers.optim import ClippingWrapper
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.runtime import Runtime
 from hivemind.p2p import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
@@ -33,7 +34,7 @@ logger = get_logger(__name__)
 
 class Server(threading.Thread):
     """
-    Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts.
+    Server allows you to host "experts" - pytorch subnetworks that can be accessed remotely by peers.
     After creation, a server should be started: see Server.run or Server.run_in_background.
 
     A working server does two things:
@@ -41,7 +42,7 @@ class Server(threading.Thread):
      - publishes updates to expert status every :update_period: seconds
 
     :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions.
-    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
+    :param module_backends: dict{expert uid (str) : ModuleBackend} for all expert hosted by this server.
     :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
         if too small for normal functioning, we recommend 4 handlers per expert backend.
     :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
@@ -54,7 +55,7 @@ class Server(threading.Thread):
     def __init__(
         self,
         dht: DHT,
-        expert_backends: Dict[str, ExpertBackend],
+        module_backends: Dict[str, ModuleBackend],
         num_connection_handlers: int = 1,
         update_period: float = 30,
         expiration: Optional[float] = None,
@@ -63,18 +64,18 @@ class Server(threading.Thread):
         **kwargs,
     ):
         super().__init__()
-        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
+        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
 
-        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [ConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:
-            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
+            self.checkpoint_saver = CheckpointSaver(module_backends, checkpoint_dir, update_period)
         else:
             self.checkpoint_saver = None
-        self.runtime = Runtime(self.experts, **kwargs)
+        self.runtime = Runtime(self.module_backends, **kwargs)
 
-        if self.experts:
+        if self.module_backends:
             self.dht_handler_thread = DHTHandlerThread(
-                experts=self.experts,
+                module_backends=self.module_backends,
                 dht=self.dht,
                 update_period=self.update_period,
                 expiration=expiration,
@@ -95,7 +96,7 @@ class Server(threading.Thread):
         optim_cls=torch.optim.Adam,
         scheduler: str = "none",
         num_warmup_steps=None,
-        num_total_steps=None,
+        num_training_steps=None,
         clip_grad_norm=None,
         num_handlers=None,
         min_batch_size=1,
@@ -113,7 +114,7 @@ class Server(threading.Thread):
         **kwargs,
     ) -> Server:
         """
-        Instantiate a server with several identical experts. See argparse comments below for details
+        Instantiate a server with several identical modules. See argparse comments below for details
 
         :param num_experts: run this many identical experts
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
@@ -129,7 +130,7 @@ class Server(threading.Thread):
         :param optim_cls: uses this optimizer to train all experts
         :param scheduler: if not `none`, the name of the expert LR scheduler
         :param num_warmup_steps: the number of warmup steps for LR schedule
-        :param num_total_steps: the total number of steps for LR schedule
+        :param num_training_steps: the total number of steps for LR schedule
         :param clip_grad_norm: maximum gradient norm used for clipping
 
         :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
@@ -138,7 +139,7 @@ class Server(threading.Thread):
 
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
-            for each BatchTensorProto in ExpertBackend for the respective experts.
+            for each BatchTensorProto in ModuleBackend for the respective experts.
 
         :param start: if True, starts server right away and returns when server is ready for requests
         :param stats_report_interval: interval between two reports of batch processing performance statistics
@@ -180,7 +181,6 @@ class Server(threading.Thread):
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
         sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim)
@@ -189,21 +189,26 @@ class Server(threading.Thread):
         else:
             args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
-        scheduler = schedule_name_to_scheduler[scheduler]
+        scheduler_cls = schedule_name_to_scheduler[scheduler]
+        if scheduler_cls is not None:
+            scheduler_cls = partial(
+                scheduler_cls, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
+            )
 
         # initialize experts
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
-            experts[expert_uid] = ExpertBackend(
+            optimizer = optim_cls(expert.parameters()) if optim_cls is not None else None
+            scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None
+            if clip_grad_norm is not None:
+                optimizer = ClippingWrapper(optimizer, clip_grad_norm)
+            experts[expert_uid] = ModuleBackend(
                 name=expert_uid,
-                expert=expert,
+                module=expert,
                 args_schema=args_schema,
-                optimizer=optim_cls(expert.parameters()),
+                optimizer=optimizer,
                 scheduler=scheduler,
-                num_warmup_steps=num_warmup_steps,
-                num_total_steps=num_total_steps,
-                clip_grad_norm=clip_grad_norm,
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )
@@ -228,15 +233,15 @@ class Server(threading.Thread):
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        logger.info(f"Server started with {len(self.experts)} experts:")
-        for expert_name, backend in self.experts.items():
-            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
+        logger.info(f"Server started with {len(self.module_backends)} modules:")
+        for expert_name, backend in self.module_backends.items():
+            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
 
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)
 
-        if self.experts:
+        if self.module_backends:
             self.dht_handler_thread.start()
 
         if self.checkpoint_saver is not None:
@@ -287,7 +292,7 @@ class Server(threading.Thread):
             process.join()
         logger.debug("Connection handlers terminated")
 
-        if self.experts:
+        if self.module_backends:
             self.dht_handler_thread.stop.set()
             self.dht_handler_thread.join()
 

+ 5 - 5
tests/test_connection_handler.py

@@ -10,7 +10,7 @@ import torch
 from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PHandlerError
 from hivemind.proto import runtime_pb2
@@ -25,7 +25,7 @@ from hivemind.utils.tensor_descr import BatchTensorDescriptor
 async def test_connection_handler_info():
     handler = ConnectionHandler(
         DHT(start=True),
-        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
     )
     handler.start()
 
@@ -48,7 +48,7 @@ async def test_connection_handler_info():
 async def test_connection_handler_forward():
     handler = ConnectionHandler(
         DHT(start=True),
-        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
     )
     handler.start()
 
@@ -109,7 +109,7 @@ async def test_connection_handler_forward():
 async def test_connection_handler_backward():
     handler = ConnectionHandler(
         DHT(start=True),
-        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
     )
     handler.start()
 
@@ -179,7 +179,7 @@ class DummyPool(TaskPool):
         return [inputs[0] * self.k]
 
 
-class DummyExpertBackend(ExpertBackend):
+class DummyModuleBackend(ModuleBackend):
     def __init__(self, name: str, k: float):
         self.name = name
         self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]

+ 10 - 8
tests/test_expert_backend.py

@@ -5,7 +5,7 @@ import pytest
 import torch
 from torch.nn import Linear
 
-from hivemind import BatchTensorDescriptor, ExpertBackend
+from hivemind import BatchTensorDescriptor, ModuleBackend
 from hivemind.moe.server.checkpoints import load_experts, store_experts
 from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
@@ -22,13 +22,15 @@ def example_experts():
     opt = torch.optim.SGD(expert.parameters(), PEAK_LR)
 
     args_schema = (BatchTensorDescriptor(1),)
-    expert_backend = ExpertBackend(
+    expert_backend = ModuleBackend(
         name=EXPERT_NAME,
-        expert=expert,
+        module=expert,
         optimizer=opt,
-        scheduler=get_linear_schedule_with_warmup,
-        num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
-        num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+        scheduler=get_linear_schedule_with_warmup(
+            opt,
+            num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
+            num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+        ),
         args_schema=args_schema,
         outputs_schema=BatchTensorDescriptor(1),
         max_batch_size=1,
@@ -39,7 +41,7 @@ def example_experts():
 
 @pytest.mark.forked
 def test_save_load_checkpoints(example_experts):
-    expert = example_experts[EXPERT_NAME].expert
+    expert = example_experts[EXPERT_NAME].module
 
     with TemporaryDirectory() as tmpdir:
         tmp_path = Path(tmpdir)
@@ -79,7 +81,7 @@ def test_restore_update_count(example_experts):
             expert_backend.backward(batch, loss_grad)
 
         load_experts(example_experts, tmp_path)
-        assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE
+        assert expert_backend.scheduler._step_count == BACKWARD_PASSES_BEFORE_SAVE + 1
 
 
 @pytest.mark.forked

+ 4 - 4
tests/test_moe.py

@@ -7,7 +7,7 @@ from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
 from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
+from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import BatchTensorDescriptor, get_dht_time
@@ -257,16 +257,16 @@ def test_client_anomaly_detection():
     experts = {}
     for i in range(4):
         expert = name_to_block["ffn"](HID_DIM)
-        experts[f"expert.{i}"] = ExpertBackend(
+        experts[f"expert.{i}"] = ModuleBackend(
             name=f"expert.{i}",
-            expert=expert,
+            module=expert,
             optimizer=torch.optim.Adam(expert.parameters()),
             args_schema=(BatchTensorDescriptor(HID_DIM),),
             outputs_schema=BatchTensorDescriptor(HID_DIM),
             max_batch_size=16,
         )
 
-    experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
+    experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")
 
     dht = DHT(start=True)
     server = Server(dht, experts, num_connection_handlers=1)