
Merge branch 'calm-v1.1' into calm

justheuristic, 3 years ago
Parent
Current commit
fa93bc53e2
39 files changed, 781 insertions and 688 deletions
  1. .github/workflows/check-style.yml (+4 -1)
  2. .github/workflows/push-docker-image.yml (+0 -1)
  3. .github/workflows/run-benchmarks.yml (+4 -2)
  4. .github/workflows/run-tests.yml (+4 -2)
  5. README.md (+4 -0)
  6. benchmarks/benchmark_optimizer.py (+1 -2)
  7. benchmarks/benchmark_throughput.py (+13 -13)
  8. docs/modules/optim.rst (+2 -2)
  9. docs/modules/server.rst (+8 -6)
  10. examples/albert/README.md (+24 -24)
  11. examples/albert/arguments.py (+14 -13)
  12. examples/albert/requirements.txt (+5 -5)
  13. examples/albert/run_trainer.py (+97 -74)
  14. examples/albert/run_training_monitor.py (+17 -20)
  15. examples/albert/tokenize_wikitext103.py (+1 -1)
  16. hivemind/__init__.py (+1 -1)
  17. hivemind/averaging/allreduce.py (+9 -5)
  18. hivemind/averaging/matchmaking.py (+2 -4)
  19. hivemind/averaging/partition.py (+9 -0)
  20. hivemind/hivemind_cli/run_server.py (+1 -1)
  21. hivemind/moe/__init__.py (+8 -1)
  22. hivemind/moe/client/moe.py (+4 -4)
  23. hivemind/moe/server/__init__.py (+3 -355)
  24. hivemind/moe/server/expert_backend.py (+2 -2)
  25. hivemind/moe/server/expert_uid.py (+2 -73)
  26. hivemind/moe/server/server.py (+419 -0)
  27. hivemind/optim/__init__.py (+1 -1)
  28. hivemind/optim/base.py (+8 -0)
  29. hivemind/optim/collaborative.py (+5 -5)
  30. hivemind/optim/experimental/__init__.py (+0 -0)
  31. hivemind/optim/grad_averager.py (+0 -0)
  32. hivemind/optim/grad_scaler.py (+16 -5)
  33. hivemind/optim/optimizer.py (+23 -11)
  34. hivemind/optim/progress_tracker.py (+17 -12)
  35. hivemind/optim/state_averager.py (+9 -0)
  36. hivemind/p2p/p2p_daemon.py (+5 -0)
  37. hivemind/utils/tensor_descr.py (+4 -4)
  38. tests/test_moe.py (+31 -34)
  39. tests/test_optimizer.py (+4 -4)

+ 4 - 1
.github/workflows/check-style.yml

@@ -1,6 +1,9 @@
 name: Check style
 
-on: [ push, pull_request ]
+on:
+  push:
+    branches: [ master ]
+  pull_request:
 
 jobs:
   black:

+ 0 - 1
.github/workflows/push-docker-image.yml

@@ -8,7 +8,6 @@ on:
   pull_request:
     branches: [ master ]
 
-
 jobs:
   build:
     runs-on: ubuntu-latest

+ 4 - 2
.github/workflows/run-benchmarks.yml

@@ -1,7 +1,9 @@
 name: Benchmarks
 
-on: [ push, pull_request ]
-
+on:
+  push:
+    branches: [ master ]
+  pull_request:
 
 jobs:
   run_benchmarks:

+ 4 - 2
.github/workflows/run-tests.yml

@@ -1,7 +1,9 @@
 name: Tests
 
-on: [ push, pull_request ]
-
+on:
+  push:
+    branches: [ master ]
+  pull_request:
 
 jobs:
   run_tests:

+ 4 - 0
README.md

@@ -12,6 +12,10 @@ large model on hundreds of computers from different universities, companies, and
 
 ![img](https://i.imgur.com/GPxolxb.gif)
 
+## Live Demo
+
+Check out our NeurIPS 2021 demonstration ["Training Transformers Together"](https://training-transformers-together.github.io/) to see hivemind in action, join an ongoing collaborative experiment, and learn more about the technologies behind it!
+
 ## Key Features
 
 * Distributed training without a master node: Distributed Hash Table allows connecting computers in a decentralized

+ 1 - 2
benchmarks/benchmark_optimizer.py

@@ -6,7 +6,6 @@ from dataclasses import dataclass
 from functools import partial
 from typing import Callable
 
-import numpy as np
 import torch
 import torchvision
 from torch import nn as nn
@@ -14,7 +13,7 @@ from torch.nn import functional as F
 from torch.utils.data import Dataset
 
 import hivemind
-from hivemind.optim.experimental.optimizer import Optimizer
+from hivemind.optim.optimizer import Optimizer
 from hivemind.utils.crypto import RSAPrivateKey
 
 

+ 13 - 13
benchmarks/benchmark_throughput.py

@@ -6,11 +6,13 @@ import time
 
 import torch
 
-import hivemind
-from hivemind import get_free_port
-from hivemind.moe.server import layers
+from hivemind.moe.client import RemoteExpert
+from hivemind.moe.server import ExpertBackend, Server
+from hivemind.moe.server.layers import name_to_block
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import LOCALHOST, get_free_port
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -32,9 +34,7 @@ def print_device_info(device=None):
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [
-        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
-    ]
+    experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -66,7 +66,7 @@ def benchmark_throughput(
         or not torch.cuda.is_initialized()
         or torch.device(device) == torch.device("cpu")
     )
-    assert expert_cls in layers.name_to_block
+    assert expert_cls in name_to_block
     port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
@@ -105,20 +105,20 @@ def benchmark_throughput(
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
-            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = hivemind.ExpertBackend(
+            expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
+            experts[f"expert{i}"] = ExpertBackend(
                 name=f"expert{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
-                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
-                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
+                args_schema=(BatchTensorDescriptor(hid_dim),),
+                outputs_schema=BatchTensorDescriptor(hid_dim),
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
-        server = hivemind.moe.Server(
+        server = Server(
             None,
             experts,
-            listen_on=f"{hivemind.LOCALHOST}:{port}",
+            listen_on=f"{LOCALHOST}:{port}",
             num_connection_handlers=num_handlers,
             device=device,
         )

+ 2 - 2
docs/modules/optim.rst

@@ -9,8 +9,8 @@
 
   <br><br>
 
-.. automodule:: hivemind.optim.experimental.optimizer
-.. currentmodule:: hivemind.optim.experimental.optimizer
+.. automodule:: hivemind.optim.optimizer
+.. currentmodule:: hivemind.optim.optimizer
 
 **hivemind.Optimizer**
 ----------------------

+ 8 - 6
docs/modules/server.rst

@@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo
 The hivemind.moe.server module is organized as follows:
 
 - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute.
-- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
 - ExpertBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
   that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests.
+- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
 - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \
   and offers those batches to Runtime_ for processing.
 
@@ -25,16 +25,18 @@ The hivemind.moe.server module is organized as follows:
    :members:
    :member-order: bysource
 
-.. _Runtime:
-.. autoclass:: Runtime
-    :members:
-    :member-order: bysource
-
 .. _ExpertBackend:
 .. autoclass:: ExpertBackend
     :members: forward, backward, apply_gradients, get_info, get_pools
     :member-order: bysource
 
+.. currentmodule:: hivemind.moe.server.runtime
+
+.. _Runtime:
+.. autoclass:: Runtime
+    :members:
+    :member-order: bysource
+
 .. currentmodule:: hivemind.moe.server.task_pool
 
 .. _TaskPool:
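
For orientation, the classes described in this docs page can be wired together roughly as follows. This is a minimal sketch, not part of the commit, mirroring the Server/ExpertBackend usage visible in the benchmarks/benchmark_throughput.py diff above; the expert name, hidden dimension, batch size, and port handling are illustrative placeholders.

```python
# Minimal sketch (not part of this commit) of hosting one "ffn" expert on a Server,
# following the same API used in benchmarks/benchmark_throughput.py above.
import torch

from hivemind.moe.server import ExpertBackend, Server
from hivemind.moe.server.layers import name_to_block
from hivemind.utils.networking import LOCALHOST, get_free_port
from hivemind.utils.tensor_descr import BatchTensorDescriptor

hid_dim = 1024                              # illustrative hidden dimension
expert = name_to_block["ffn"](hid_dim)      # a simple feed-forward expert block

backend = ExpertBackend(
    name="expert0",
    expert=expert,
    optimizer=torch.optim.Adam(expert.parameters()),
    args_schema=(BatchTensorDescriptor(hid_dim),),
    outputs_schema=BatchTensorDescriptor(hid_dim),
    max_batch_size=4096,
)

port = get_free_port()
# dht=None keeps the server invisible to the DHT; clients can still reach it directly
# with RemoteExpert("expert0", endpoint=f"{LOCALHOST}:{port}").
server = Server(
    None, {"expert0": backend},
    listen_on=f"{LOCALHOST}:{port}", num_connection_handlers=4, start=True,
)
server.ready.wait()   # Runtime and connection handlers are up; TaskPools batch incoming requests
server.shutdown()
```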

+ 24 - 24
examples/albert/README.md

@@ -9,7 +9,7 @@ using `hivemind.CollaborativeOptimizer` to exchange information between peers.
 
 * Install hivemind: `pip install git+https://github.com/learning-at-home/hivemind.git`
 * Dependencies: `pip install -r requirements.txt`
-* Preprocess data: `python tokenize_wikitext103.py`
+* Preprocess data: `./tokenize_wikitext103.py`
 * Upload the data to a publicly available location or ask volunteers to preprocess it locally
 
 ## Running an experiment
@@ -20,18 +20,16 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,
 
 - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics. If you're unfamiliar with Weights
   & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-- Run `python run_training_monitor.py --experiment_prefix YOUR_EXPERIMENT_NAME --wandb_project YOUR_WANDB_PROJECT`
+- Run `./run_training_monitor.py --wandb_project YOUR_WANDB_PROJECT`
 
-  - `YOUR_EXPERIMENT_NAME` must be a unique name of this training run, e.g. `my-albert-v1`. It cannot contain `.`
-    due to naming conventions.
   - `YOUR_WANDB_PROJECT` is a name of wandb project used to track training metrics. Multiple experiments can have the
     same project name.
 
 ```
-$ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
-Oct 14 16:26:36.083 [INFO] [utils.log_visible_maddrs:47] Running a DHT peer. To connect other peers to this one over the Internet,
+$ ./run_training_monitor.py --wandb_project Demo-run
+Oct 14 16:26:36.083 [INFO] Running a DHT peer. To connect other peers to this one over the Internet,
 use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
-Oct 14 16:26:36.083 [INFO] [utils.log_visible_maddrs:50] Full list of visible multiaddresses: ...
+Oct 14 16:26:36.083 [INFO] Full list of visible multiaddresses: ...
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32
 wandb: Syncing run dry-mountain-2
@@ -39,12 +37,12 @@ wandb:  View project at https://wandb.ai/XXX/Demo-run
 wandb:  View run at https://wandb.ai/XXX/Demo-run/runs/YYY
 wandb: Run data is saved locally in /path/to/run/data
 wandb: Run `wandb offline` to turn off syncing.
-Oct 14 16:26:41.064 [INFO] [optim.collaborative._fetch_state:448] Found no active peers: None
-Oct 14 16:26:44.068 [INFO] [optim.collaborative._fetch_state:448] Found no active peers: None
+Oct 14 16:26:41.064 [INFO] Found no active peers: None
+Oct 14 16:26:44.068 [INFO] Found no active peers: None
 ...
-Oct 14 16:37:37.246 [INFO] [__main__.<module>:209] Step #1  loss = 11.05164
-Oct 14 16:39:37.441 [INFO] [__main__.<module>:209] Step #2  loss = 11.03771
-Oct 14 16:40:37.541 [INFO] [__main__.<module>:209] Step #3  loss = 11.02886
+Oct 14 16:37:37.246 [INFO] Step #1  loss = 11.05164
+Oct 14 16:39:37.441 [INFO] Step #2  loss = 11.03771
+Oct 14 16:40:37.541 [INFO] Step #3  loss = 11.02886
 ```
 
 ### GPU trainers
@@ -57,8 +55,8 @@ To join the collaboration with a GPU trainer,
   (see [default paths](./arguments.py#L117-L134) for reference)
 - Run:
   ```bash
-  python run_trainer.py \
-      --experiment_prefix YOUR_EXPERIMENT_NAME --initial_peers ONE_OR_MORE_PEERS \
+  ./run_trainer.py \
+      --initial_peers ONE_OR_MORE_PEERS \
       --logging_first_step --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
   ```
 
@@ -89,16 +87,18 @@ See the ["Tips and tricks"](#tips-and-tricks) section for more information on se
 As the peer begins training, it will periodically report training logs in the following form:
 
 ```
-... [INFO] [...] my-albert-v1 accumulated 448 samples from 17 peers for step #0. ETA 18.88 sec (refresh in 15.73 sec)
-... [INFO] [...] my-albert-v1 accumulated 4096 samples from 16 peers for step #0. ETA 0.00 sec (refresh in 0.50 sec)
-... [INFO] [optim.collaborative.step:283] Averaged tensors successfully with 17 peers
-... [INFO] [optim.collaborative.step:317] Optimizer step: done!
-Oct 14 18:58:03.750 [INFO] [__main__.on_step_end:141] Step 1
-Oct 14 18:58:03.750 [INFO] [__main__.on_step_end:142] Your current contribution: 892 samples
-Oct 14 18:58:03.750 [INFO] [__main__.on_step_end:143] Local loss: 11.023
+Dec 28 00:15:31.482 [INFO] albert accumulated 4056 samples for epoch #0 from 2 peers. ETA 0.75 sec (refresh in 0.50 sec)
+Dec 28 00:15:31.990 [INFO] albert accumulated 4072 samples for epoch #0 from 2 peers. ETA 0.24 sec (refresh in 0.50 sec)
+...
+Dec 28 00:15:32.857 [INFO] Step #1
+Dec 28 00:15:32.857 [INFO] Your current contribution: 2144 samples
+Dec 28 00:15:32.857 [INFO] Performance: 20.924 samples/sec
+Dec 28 00:15:32.857 [INFO] Local loss: 11.06709
+Dec 28 00:15:33.580 [INFO] Averaged gradients with 2 peers
+Dec 28 00:15:38.336 [INFO] Averaged parameters with 2 peers
 ```
 
-__Sanity check:__ a healthy peer will periodically report `Averaged tensors successfully with [N > 1]` peers.
+__Sanity check:__ a healthy peer will periodically report `Averaged gradients/parameters with [N > 1]` peers.
 
 For convenience, you can view (and share!) the learning curves of your collaborative experiments in wandb:
 
@@ -169,8 +169,8 @@ Here's an example of a full trainer script for Google Colab:
 !pip install transformers datasets sentencepiece torch_optimizer==0.1.0
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
 !curl -L YOUR_HOSTED_DATA | tar xzf -
-!ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \
-    --experiment_prefix YOUR_EXPERIMENT_NAME --initial_peers ONE_OR_MORE_PEERS \
+!ulimit -n 4096 && ./hivemind/examples/albert/run_trainer.py \
+    --initial_peers ONE_OR_MORE_PEERS \
     --logging_dir ./logs --logging_first_step --output_dir ./outputs --overwrite_output_dir \
     --client_mode --averaging_expiration 10 --batch_size_lead 300 --gradient_accumulation_steps 1
 ```

+ 14 - 13
examples/albert/arguments.py

@@ -7,7 +7,7 @@ from transformers import TrainingArguments
 @dataclass
 class BaseTrainingArguments:
     experiment_prefix: str = field(
-        metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
+        default="albert", metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
     )
     initial_peers: List[str] = field(
         default_factory=list,
@@ -45,12 +45,11 @@ class BaseTrainingArguments:
 
 @dataclass
 class AveragerArguments:
-    averaging_expiration: float = field(
-        default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
-    )
-    averaging_timeout: float = field(
-        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
-    )
+    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
+
+
+@dataclass
+class ProgressTrackerArguments:
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
     )
@@ -66,17 +65,13 @@ class AveragerArguments:
     expected_drift_rate: float = field(
         default=0.2, metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
     )
-    performance_ema_alpha: float = field(
-        default=0.1, metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
-    )
-    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
     metadata_expiration: float = field(
         default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
     )
 
 
 @dataclass
-class CollaborativeOptimizerArguments:
+class OptimizerArguments:
     target_batch_size: int = field(
         default=4096,
         metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
@@ -93,10 +88,16 @@ class CollaborativeOptimizerArguments:
         default=100.0,
         metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
     )
+    averaging_timeout: float = field(
+        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
+    )
+    matchmaking_time: float = field(
+        default=5.0, metadata={"help": "When looking for group, wait for requests for at least this many seconds"}
+    )
 
 
 @dataclass
-class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArguments):
+class CollaborationArguments(OptimizerArguments, BaseTrainingArguments):
     statistics_expiration: float = field(
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )

+ 5 - 5
examples/albert/requirements.txt

@@ -1,7 +1,7 @@
-transformers>=4.6.0
-datasets>=1.5.0
-torch_optimizer>=0.1.0
-wandb>=0.10.26
+transformers==4.6.0
+datasets==1.5.0
+torch_optimizer==0.1.0
+wandb==0.10.26
 sentencepiece
 requests
-nltk>=3.6.2
+nltk==3.6.5

+ 97 - 74
examples/albert/run_trainer.py

@@ -1,7 +1,8 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 import os
 import pickle
+import sys
 from dataclasses import asdict
 from pathlib import Path
 
@@ -16,11 +17,17 @@ from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
-import hivemind
+from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
-from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
+from arguments import (
+    AlbertTrainingArguments,
+    AveragerArguments,
+    CollaborationArguments,
+    DatasetArguments,
+    ProgressTrackerArguments,
+)
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -52,36 +59,6 @@ def get_model(training_args, config, tokenizer):
     return model
 
 
-def get_optimizer_and_scheduler(training_args, model):
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": training_args.weight_decay,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-
-    opt = Lamb(
-        optimizer_grouped_parameters,
-        lr=training_args.learning_rate,
-        betas=(training_args.adam_beta1, training_args.adam_beta2),
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        clamp_value=training_args.clamp_value,
-        debias=True,
-    )
-
-    scheduler = get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
-    )
-
-    return opt, scheduler
-
-
 class CollaborativeCallback(transformers.TrainerCallback):
     """
     This callback monitors and reports collaborative training progress.
@@ -90,8 +67,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     def __init__(
         self,
-        dht: hivemind.DHT,
-        optimizer: hivemind.CollaborativeOptimizer,
+        dht: DHT,
+        optimizer: Optimizer,
         model: torch.nn.Module,
         local_public_key: bytes,
         statistics_expiration: float,
@@ -99,7 +76,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
     ):
         super().__init__()
         self.model = model
-        self.dht, self.collaborative_optimizer = dht, optimizer
+        self.dht, self.optimizer = dht, optimizer
         self.local_public_key = local_public_key
         self.statistics_expiration = statistics_expiration
         self.last_reported_collaboration_step = -1
@@ -114,7 +91,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
     ):
         logger.info("Loading state from peers")
-        self.collaborative_optimizer.load_state_from_peers()
+        self.optimizer.load_state_from_peers()
 
     def on_step_end(
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
@@ -124,40 +101,43 @@ class CollaborativeCallback(transformers.TrainerCallback):
             self.restore_from_backup(self.latest_backup)
             return control
 
+        local_progress = self.optimizer.local_progress
+
         if state.log_history:
             self.loss += state.log_history[-1]["loss"]
             self.steps += 1
-            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
-                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
+
+            if self.optimizer.local_epoch != self.last_reported_collaboration_step:
+                self.last_reported_collaboration_step = self.optimizer.local_epoch
                 self.total_samples_processed += self.samples
-                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
+                samples_per_second = local_progress.samples_per_second
                 statistics = utils.LocalMetrics(
-                    step=self.collaborative_optimizer.local_step,
+                    step=self.optimizer.local_epoch,
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.optimizer.local_epoch}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
-                logger.info(f"Performance: {samples_per_second} samples per second.")
+                logger.info(f"Performance: {samples_per_second:.3f} samples/sec")
                 if self.steps:
-                    logger.info(f"Local loss: {self.loss / self.steps}")
-                if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
+                    logger.info(f"Local loss: {self.loss / self.steps:.5f}")
+                if self.optimizer.local_epoch % self.backup_every_steps == 0:
                     self.latest_backup = self.backup_state()
 
                 self.loss = 0
                 self.steps = 0
-                if self.collaborative_optimizer.is_synchronized:
+                if self.optimizer.is_synchronized_with_peers():
                     self.dht.store(
-                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        key=self.optimizer.run_id + "_metrics",
                         subkey=self.local_public_key,
                         value=statistics.dict(),
-                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                        expiration_time=get_dht_time() + self.statistics_expiration,
                         return_future=True,
                     )
 
-        self.samples = self.collaborative_optimizer.local_samples_accumulated
+        self.samples = local_progress.samples_accumulated
 
         return control
 
@@ -170,19 +150,17 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     @torch.no_grad()
     def backup_state(self) -> bytes:
-        return pickle.dumps(
-            {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
-        )
+        return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})
 
     @torch.no_grad()
     def restore_from_backup(self, backup: bytes):
         state = pickle.loads(backup)
         self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])
+        self.optimizer.load_state_dict(state["optimizer"])
 
 
 class NoOpScheduler(LRSchedulerBase):
-    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""
 
     def get_lr(self):
         return [group["lr"] for group in self.optimizer.param_groups]
@@ -202,12 +180,17 @@ class NoOpScheduler(LRSchedulerBase):
 
 
 def main():
-    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
-    training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
-
+    parser = HfArgumentParser(
+        (
+            AlbertTrainingArguments,
+            DatasetArguments,
+            CollaborationArguments,
+            AveragerArguments,
+            ProgressTrackerArguments,
+        )
+    )
+    training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()
     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
-    if len(collaboration_args.initial_peers) == 0:
-        raise ValueError("Please specify at least one network endpoint in initial peers.")
 
     setup_transformers_logging(training_args.local_rank)
     logger.info(f"Training/evaluation parameters:\n{training_args}")
@@ -216,7 +199,15 @@ def main():
     set_seed(training_args.seed)
 
     config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
-    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    try:
+        tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    except OSError:
+        logger.fatal(
+            f"No tokenizer data found in {dataset_args.tokenizer_path}, "
+            f"please run ./tokenize_wikitext103.py before running this"
+        )
+        sys.exit(1)
+
     model = get_model(training_args, config, tokenizer)
     model.to(training_args.device)
 
@@ -224,11 +215,9 @@ def main():
     # This data collator will take care of randomly masking the tokens.
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
 
-    opt, scheduler = get_optimizer_and_scheduler(training_args, model)
-
     validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
 
-    dht = hivemind.DHT(
+    dht = DHT(
         start=True,
         initial_peers=collaboration_args.initial_peers,
         client_mode=collaboration_args.client_mode,
@@ -246,19 +235,53 @@ def main():
 
     adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
 
-    collaborative_optimizer = hivemind.CollaborativeOptimizer(
-        opt=opt,
+    # We need to make such a lambda function instead of just an optimizer instance
+    # to make hivemind.Optimizer(..., offload_optimizer=True) work
+    opt = lambda params: Lamb(
+        params,
+        lr=training_args.learning_rate,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        clamp_value=training_args.clamp_value,
+        debias=True,
+    )
+
+    no_decay = ["bias", "LayerNorm.weight"]
+    params = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    scheduler = lambda opt: get_linear_schedule_with_warmup(
+        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+    )
+
+    optimizer = Optimizer(
         dht=dht,
-        scheduler=scheduler,
-        prefix=collaboration_args.experiment_prefix,
-        compression=hivemind.Float16Compression(),
-        batch_size_per_step=total_batch_size_per_step,
-        bandwidth=collaboration_args.bandwidth,
+        run_id=collaboration_args.experiment_prefix,
         target_batch_size=adjusted_target_batch_size,
+        batch_size_per_step=total_batch_size_per_step,
+        optimizer=opt,
+        params=params,
+        scheduler=scheduler,
+        matchmaking_time=collaboration_args.matchmaking_time,
+        averaging_timeout=collaboration_args.averaging_timeout,
+        offload_optimizer=True,
+        delay_optimizer_step=True,
+        delay_grad_averaging=True,
         client_mode=collaboration_args.client_mode,
+        grad_compression=Float16Compression(),
+        state_averaging_compression=Float16Compression(),
+        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
+        tracker_opts=asdict(tracker_args),
         verbose=True,
-        start=True,
-        **asdict(averager_args),
     )
 
     class TrainerWithIndependentShuffling(Trainer):
@@ -274,11 +297,11 @@ def main():
         data_collator=data_collator,
         train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
-        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
+        optimizers=(optimizer, NoOpScheduler(optimizer)),
         callbacks=[
             CollaborativeCallback(
                 dht,
-                collaborative_optimizer,
+                optimizer,
                 model,
                 local_public_key,
                 collaboration_args.statistics_expiration,
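
Outside of the transformers.Trainer integration shown above, the new hivemind.Optimizer can be driven much like a regular PyTorch optimizer. The snippet below is a minimal sketch based on the constructor call in this diff; the toy model, loss, loop, and batch sizes are placeholder assumptions rather than part of the ALBERT example.

```python
# Minimal sketch of the hivemind.Optimizer API that replaces CollaborativeOptimizer here.
# The model, loss, and batch sizes are placeholders; run_id and target_batch_size mirror the ALBERT example.
import torch
from hivemind import DHT, Optimizer

model = torch.nn.Linear(64, 2)
dht = DHT(initial_peers=[], start=True)   # in a real run, pass the --initial_peers multiaddrs here

optimizer = Optimizer(
    dht=dht,
    run_id="albert",                      # must be identical across all collaborating peers
    target_batch_size=4096,               # global number of samples per optimizer epoch
    batch_size_per_step=32,               # samples contributed by each local step() call
    optimizer=lambda params: torch.optim.Adam(params),  # factory, as in run_trainer.py
    params=model.parameters(),
    matchmaking_time=5.0,
    averaging_timeout=60.0,
    verbose=True,
)

for _ in range(10):                       # stand-in for the actual training loop
    loss = model(torch.randn(32, 64)).pow(2).mean()
    loss.backward()
    optimizer.step()                      # averages with peers once the target batch size is reached
    optimizer.zero_grad()
```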

+ 17 - 20
examples/albert/run_training_monitor.py

@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 import time
 from dataclasses import asdict, dataclass, field
@@ -12,10 +12,11 @@ from torch_optimizer import Lamb
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
+from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
-from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
+from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -55,14 +56,14 @@ class TrainingMonitorArguments(BaseTrainingArguments):
     upload_interval: Optional[float] = field(
         default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
     )
-    store_checkpoins: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
+    store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
 
 
 class CheckpointHandler:
     def __init__(
         self,
         monitor_args: TrainingMonitorArguments,
-        collab_optimizer_args: CollaborativeOptimizerArguments,
+        optimizer_args: OptimizerArguments,
         averager_args: AveragerArguments,
         dht: hivemind.DHT,
     ):
@@ -95,17 +96,13 @@ class CheckpointHandler:
             debias=True,
         )
 
-        adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead
-
-        self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
-            opt=opt,
+        self.state_averager = TrainingStateAverager(
             dht=dht,
+            optimizer=opt,
             prefix=experiment_prefix,
-            compression_type=hivemind.Float16Compression(),
-            bandwidth=collab_optimizer_args.bandwidth,
-            target_batch_size=adjusted_target_batch_size,
-            client_mode=collab_optimizer_args.client_mode,
-            verbose=True,
+            state_compression=hivemind.Float16Compression(),
+            bandwidth=optimizer_args.bandwidth,
+            client_mode=optimizer_args.client_mode,
             start=True,
             **asdict(averager_args),
         )
@@ -121,7 +118,7 @@ class CheckpointHandler:
 
     def save_state(self, cur_step):
         logger.info("Saving state from peers")
-        self.collaborative_optimizer.load_state_from_peers()
+        self.state_averager.load_state_from_peers()
         self.previous_step = cur_step
 
     def is_time_to_upload(self):
@@ -134,7 +131,7 @@ class CheckpointHandler:
 
     def upload_checkpoint(self, current_loss):
         logger.info("Saving optimizer")
-        torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
+        torch.save(self.state_averager.optimizer.state_dict(), f"{self.repo_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
         logger.info("Started uploading to Model Hub")
         self.model.push_to_hub(
@@ -146,8 +143,8 @@ class CheckpointHandler:
 
 
 if __name__ == "__main__":
-    parser = HfArgumentParser((TrainingMonitorArguments, CollaborativeOptimizerArguments, AveragerArguments))
-    monitor_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((TrainingMonitorArguments, OptimizerArguments, AveragerArguments))
+    monitor_args, optimizer_args, averager_args = parser.parse_args_into_dataclasses()
 
     if monitor_args.use_google_dns:
         request = requests.get("https://api.ipify.org")
@@ -176,8 +173,8 @@ if __name__ == "__main__":
         wandb.init(project=monitor_args.wandb_project)
 
     current_step = 0
-    if monitor_args.store_checkpoins:
-        checkpoint_handler = CheckpointHandler(monitor_args, collab_optimizer_args, averager_args, dht)
+    if monitor_args.store_checkpoints:
+        checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)
 
     while True:
         metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
@@ -219,7 +216,7 @@ if __name__ == "__main__":
                         }
                     )
 
-                if monitor_args.store_checkpoins:
+                if monitor_args.store_checkpoints:
                     if checkpoint_handler.is_time_to_save_state(current_step):
                         checkpoint_handler.save_state(current_step)
                         if checkpoint_handler.is_time_to_upload():

+ 1 - 1
examples/albert/tokenize_wikitext103.py

@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 """ This script builds a pre-tokenized compressed representation of WikiText-103 using huggingface/datasets """
 import random
 from functools import partial

+ 1 - 1
hivemind/__init__.py

@@ -23,4 +23,4 @@ from hivemind.optim import (
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
 
-__version__ = "1.0.0dev0"
+__version__ = "1.1.0dev0"

+ 9 - 5
hivemind/averaging/allreduce.py

@@ -4,7 +4,7 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Typ
 
 import torch
 
-from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
+from hivemind.averaging.partition import AllreduceException, BannedException, TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
@@ -343,10 +343,14 @@ class AllReduceRunner(ServicerBase):
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
             ):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(
-                    sender_index, part_index, tensor_part, weight=weight
-                )
-                part_index += 1
+                try:
+                    averaged_part = await self.tensor_part_reducer.accumulate_part(
+                        sender_index, part_index, tensor_part, weight=weight
+                    )
+                    part_index += 1
+                except BannedException:
+                    logger.debug(f"Sender {sender_index} is already banned")
+                    break  # sender was banned, we no longer need to aggregate it
 
                 serialized_delta = await loop.run_in_executor(
                     None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)

+ 2 - 4
hivemind/averaging/matchmaking.py

@@ -9,8 +9,6 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
-import numpy as np
-
 from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
@@ -203,7 +201,7 @@ class Matchmaking:
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
-                    logger.debug(f"{self.peer_id} - joining the group of {leader}; waiting for peers")
+                    logger.debug(f"{self.peer_id} - joining the group of {leader}, waiting for peers")
                     self.current_leader = leader
                     self.was_accepted_to_group.set()
                     if len(self.current_followers) > 0:
@@ -242,7 +240,7 @@ class Matchmaking:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
         except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
-            logger.debug(f"{self} - failed to request potential leader {leader}:")
+            logger.debug(f"{self} - failed to request potential leader {leader}:", exc_info=True)
             return None
 
         finally:

+ 9 - 0
hivemind/averaging/partition.py

@@ -227,6 +227,9 @@ class TensorPartReducer:
             await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
             if self.finished.is_set():
                 raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
+
+        if self.sender_failed_after[sender_index] != float("inf"):
+            raise BannedException(f"sender {sender_index} was banned in background")
         assert part_index == self.current_part_index
 
         current_part_future = self.current_part_future
@@ -241,6 +244,8 @@ class TensorPartReducer:
     def on_sender_failed(self, sender_index: int):
         """Exclude that sender's data for averaging any parts that it did not submit yet."""
         self.sender_failed_after[sender_index] = self.num_parts_received[sender_index]
+        if self.finished.is_set():
+            return
         if self.current_part_index == self.num_parts_received[sender_index]:
             self.num_current_senders -= 1
             self.check_current_part_finished()
@@ -270,3 +275,7 @@ class TensorPartReducer:
 
 class AllreduceException(Exception):
     """A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error)"""
+
+
+class BannedException(AllreduceException):
+    """An exception that indicates that a given sender was banned and will no longer be aggregated"""

+ 1 - 1
hivemind/hivemind_cli/run_server.py

@@ -4,7 +4,7 @@ from pathlib import Path
 import configargparse
 import torch
 
-from hivemind.moe.server import Server
+from hivemind.moe import Server
 from hivemind.moe.server.layers import schedule_name_to_scheduler
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit

+ 8 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,9 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class
+from hivemind.moe.server import (
+    ExpertBackend,
+    Server,
+    background_server,
+    declare_experts,
+    get_experts,
+    register_expert_class,
+)

+ 4 - 4
hivemind/moe/client/moe.py

@@ -9,8 +9,8 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
@@ -48,7 +48,7 @@ class RemoteMixtureOfExperts(nn.Module):
         *,
         in_features,
         grid_size: Tuple[int, ...],
-        dht: hivemind.DHT,
+        dht: DHT,
         uid_prefix: str,
         k_best: int,
         k_min: int = 1,
@@ -245,7 +245,7 @@ class _RemoteCallMany(torch.autograd.Function):
         else:
             outputs_schema = info["outputs_schema"]
         outputs = nested_map(
-            lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
+            lambda descriptor: descriptor.make_zeros(num_samples, max_experts, device=flat_inputs[0].device),
             outputs_schema,
         )
 
@@ -341,7 +341,7 @@ class _RemoteCallMany(torch.autograd.Function):
         # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
         grad_inputs = nested_map(
-            lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
+            lambda descr: descr.make_zeros(num_samples, device=flat_grad_outputs[0].device),
             list(nested_flatten(info["forward_schema"])),
         )
 

+ 3 - 355
hivemind/moe/server/__init__.py

@@ -1,356 +1,4 @@
-from __future__ import annotations
-
-import multiprocessing as mp
-import multiprocessing.synchronize
-import threading
-from contextlib import contextmanager
-from functools import partial
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from multiaddr import Multiaddr
-
-import hivemind
-from hivemind.dht import DHT
-from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
-from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
+from hivemind.moe.server.dht_handler import declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
-from hivemind.moe.server.layers import (
-    add_custom_models_from_file,
-    name_to_block,
-    name_to_input,
-    register_expert_class,
-    schedule_name_to_scheduler,
-)
-from hivemind.moe.server.runtime import Runtime
-from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
-
-logger = get_logger(__name__)
-
-
-class Server(threading.Thread):
-    """
-    Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
-    After creation, a server should be started: see Server.run or Server.run_in_background.
-
-    A working server does 3 things:
-     - processes incoming forward/backward requests via Runtime (created by the server)
-     - publishes updates to expert status every :update_period: seconds
-     - follows orders from HivemindController - if it exists
-
-    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
-    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
-    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
-    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
-        if too small for normal functioning, we recommend 4 handlers per expert backend.
-    :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
-        if dht is None, this parameter is ignored.
-    :param start: if True, the server will immediately start as a background thread and returns control after server
-        is ready (see .ready below)
-    """
-
-    def __init__(
-        self,
-        dht: Optional[DHT],
-        expert_backends: Dict[str, ExpertBackend],
-        listen_on: Endpoint = "0.0.0.0:*",
-        num_connection_handlers: int = 1,
-        update_period: int = 30,
-        start=False,
-        checkpoint_dir=None,
-        **kwargs,
-    ):
-        super().__init__()
-        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=get_free_port())
-        self.listen_on, self.port = listen_on, get_port(listen_on)
-
-        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
-        if checkpoint_dir is not None:
-            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
-        else:
-            self.checkpoint_saver = None
-        self.runtime = Runtime(self.experts, **kwargs)
-
-        if self.dht and self.experts:
-            self.dht_handler_thread = DHTHandlerThread(
-                experts=self.experts,
-                dht=self.dht,
-                endpoint=self.listen_on,
-                update_period=self.update_period,
-                daemon=True,
-            )
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-    @classmethod
-    def create(
-        cls,
-        listen_on="0.0.0.0:*",
-        num_experts: int = None,
-        expert_uids: str = None,
-        expert_pattern: str = None,
-        expert_cls="ffn",
-        hidden_dim=1024,
-        optim_cls=torch.optim.Adam,
-        scheduler: str = "none",
-        num_warmup_steps=None,
-        num_total_steps=None,
-        clip_grad_norm=None,
-        num_handlers=None,
-        min_batch_size=1,
-        max_batch_size=4096,
-        device=None,
-        no_dht=False,
-        initial_peers=(),
-        checkpoint_dir: Optional[Path] = None,
-        compression=CompressionType.NONE,
-        stats_report_interval: Optional[int] = None,
-        custom_module_path=None,
-        *,
-        start: bool,
-    ) -> Server:
-        """
-        Instantiate a server with several identical experts. See argparse comments below for details
-        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
-        :param num_experts: run this many identical experts
-        :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-           means "sample random experts between myprefix.0.0 and myprefix.255.255;
-        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
-        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
-        :param hidden_dim: main dimension for expert_cls
-        :param num_handlers: server will use this many parallel processes to handle incoming requests
-        :param min_batch_size: total num examples in the same batch will be greater than this value
-        :param max_batch_size: total num examples in the same batch will not exceed this value
-        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
-
-        :param optim_cls: uses this optimizer to train all experts
-        :param scheduler: if not `none`, the name of the expert LR scheduler
-        :param num_warmup_steps: the number of warmup steps for LR schedule
-        :param num_total_steps: the total number of steps for LR schedule
-        :param clip_grad_norm: maximum gradient norm used for clipping
-
-        :param no_dht: if specified, the server will not be attached to a dht
-        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
-
-        :param checkpoint_dir: directory to save and load expert checkpoints
-
-        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
-            hosted on this server. For a more fine-grained compression, start server in python and specify compression
-            for each BatchTensorProto in ExpertBackend for the respective experts.
-
-        :param start: if True, starts server right away and returns when server is ready for requests
-        :param stats_report_interval: interval between two reports of batch processing performance statistics
-        """
-        if custom_module_path is not None:
-            add_custom_models_from_file(custom_module_path)
-        assert expert_cls in name_to_block
-
-        if no_dht:
-            dht = None
-        else:
-            dht = hivemind.DHT(initial_peers=initial_peers, start=True)
-            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
-
-        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
-            num_experts is not None and expert_uids is None
-        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
-
-        if expert_uids is None:
-            if checkpoint_dir is not None:
-                assert is_directory(checkpoint_dir)
-                expert_uids = [
-                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
-                ]
-                total_experts_in_checkpoint = len(expert_uids)
-                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
-
-                if total_experts_in_checkpoint > num_experts:
-                    raise ValueError(
-                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
-                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
-                    )
-            else:
-                expert_uids = []
-
-            uids_to_generate = num_experts - len(expert_uids)
-            if uids_to_generate > 0:
-                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
-                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
-
-        num_experts = len(expert_uids)
-        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
-        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-
-        sample_input = name_to_input[expert_cls](3, hidden_dim)
-        if isinstance(sample_input, tuple):
-            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
-        else:
-            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
-
-        scheduler = schedule_name_to_scheduler[scheduler]
-
-        # initialize experts
-        experts = {}
-        for expert_uid in expert_uids:
-            expert = name_to_block[expert_cls](hidden_dim)
-            experts[expert_uid] = hivemind.ExpertBackend(
-                name=expert_uid,
-                expert=expert,
-                args_schema=args_schema,
-                optimizer=optim_cls(expert.parameters()),
-                scheduler=scheduler,
-                num_warmup_steps=num_warmup_steps,
-                num_total_steps=num_total_steps,
-                clip_grad_norm=clip_grad_norm,
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        if checkpoint_dir is not None:
-            load_experts(experts, checkpoint_dir)
-
-        return cls(
-            dht,
-            experts,
-            listen_on=listen_on,
-            num_connection_handlers=num_handlers,
-            device=device,
-            checkpoint_dir=checkpoint_dir,
-            stats_report_interval=stats_report_interval,
-            start=start,
-        )
-
-    def run(self):
-        """
-        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
-        runs Runtime (self.runtime) to process incoming requests.
-        """
-        logger.info(f"Server started at {self.listen_on}")
-        logger.info(f"Got {len(self.experts)} experts:")
-        for expert_name, backend in self.experts.items():
-            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
-
-        if self.dht:
-            if not self.dht.is_alive():
-                self.dht.run_in_background(await_ready=True)
-
-            if self.experts:
-                self.dht_handler_thread.start()
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
-        for process in self.conn_handlers:
-            if not process.is_alive():
-                process.start()
-            process.ready.wait()
-
-        try:
-            self.runtime.run()
-        finally:
-            self.shutdown()
-
-    def run_in_background(self, await_ready=True, timeout=None):
-        """
-        Starts Server in a background thread. if await_ready, this method will wait until background server
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
-
-    @property
-    def ready(self) -> mp.synchronize.Event:
-        """
-        An event (multiprocessing.Event) that is set when the server is ready to process requests.
-
-        Example
-        =======
-        >>> server.start()
-        >>> server.ready.wait(timeout=10)
-        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
-        """
-        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
-
-    def shutdown(self):
-        """
-        Gracefully terminate the server, process-safe.
-        Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
-        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
-        """
-        self.ready.clear()
-
-        for process in self.conn_handlers:
-            process.terminate()
-            process.join()
-        logger.debug("Connection handlers terminated")
-
-        if self.dht and self.experts:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
-
-        if self.dht is not None:
-            self.dht.shutdown()
-            self.dht.join()
-
-        logger.debug(f"Shutting down runtime")
-
-        self.runtime.shutdown()
-        logger.info("Server shutdown succesfully")
-
-
-@contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, List[Multiaddr]]:
-    """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
-    pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
-    try:
-        runner.start()
-        # once the server is ready, runner will send us
-        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
-        start_ok, data = pipe.recv()
-        if start_ok:
-            yield data
-            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
-        else:
-            raise RuntimeError(f"Server failed to start: {data}")
-    finally:
-        runner.join(timeout=shutdown_timeout)
-        if runner.is_alive():
-            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
-            runner.kill()
-            logger.info("Server terminated")
-
-
-def _server_runner(pipe, *args, **kwargs):
-    try:
-        server = Server.create(*args, start=True, **kwargs)
-    except Exception as e:
-        logger.exception(f"Encountered an exception when starting a server: {e}")
-        pipe.send((False, f"{type(e).__name__} {e}"))
-        return
-
-    try:
-        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
-        pipe.send((True, (server.listen_on, dht_maddrs)))
-        pipe.recv()  # wait for shutdown signal
-
-    finally:
-        logger.info("Shutting down server...")
-        server.shutdown()
-        server.join()
-        logger.info("Server shut down")
+from hivemind.moe.server.layers import register_expert_class
+from hivemind.moe.server.server import Server, background_server

+ 2 - 2
hivemind/moe/server/expert_backend.py

@@ -74,8 +74,8 @@ class ExpertBackend:
 
         if outputs_schema is None:
             # run expert once to get outputs schema
-            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
-            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
+            dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
+            dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 

+ 2 - 73
hivemind/moe/server/expert_uid.py

@@ -1,12 +1,7 @@
-import random
 import re
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import NamedTuple, Tuple, Union
 
-import hivemind
-from hivemind.dht import DHT
-from hivemind.utils import Endpoint, get_logger
-
-logger = get_logger(__name__)
+from hivemind.utils import Endpoint
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
 UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
@@ -32,69 +27,3 @@ def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPref
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
     pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
     return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
-
-
-def generate_uids_from_pattern(
-    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
-) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if "[" not in block and "]" not in block:
-                    uid.append(block)
-                elif block.startswith("[") and block.endswith("]") and ":" in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(":"))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt:
-                raise
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {
-                found_expert.uid
-                for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
-                if found_expert is not None
-            }
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(
-            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-            f"{attempts_per_expert * num_experts} attempts"
-        )
-    return found_uids
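
For reference, the main helper that remains in this module is split_uid; a minimal sketch of its behavior, using a uid in the same format as the tests further below:

    from hivemind.moe.server.expert_uid import split_uid

    prefix, coordinate = split_uid("ffn.5.10.15")   # uid format used by the MoE tests
    assert prefix == "ffn.5.10." and coordinate == 15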

+ 419 - 0
hivemind/moe/server/server.py

@@ -0,0 +1,419 @@
+from __future__ import annotations
+
+import multiprocessing as mp
+import random
+import threading
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from multiaddr import Multiaddr
+
+from hivemind.dht import DHT
+from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.moe.server.layers import (
+    add_custom_models_from_file,
+    name_to_block,
+    name_to_input,
+    schedule_name_to_scheduler,
+)
+from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.logging import get_logger
+from hivemind.utils.networking import Endpoint, get_free_port, get_port, replace_port
+from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
+
+logger = get_logger(__name__)
+
+
+class Server(threading.Thread):
+    """
+    Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts.
+    After creation, a server should be started: see Server.run or Server.run_in_background.
+
+    A working server does two things:
+     - processes incoming forward/backward requests via Runtime (created by the server)
+     - publishes updates to expert status every :update_period: seconds
+
+    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
+     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
+    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all experts hosted by this server.
+    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
+    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
+        is too small for normal operation; we recommend 4 handlers per expert backend.
+    :param update_period: how often the server will attempt to publish its state (i.e. experts) to the DHT;
+        if dht is None, this parameter is ignored.
+    :param start: if True, the server will immediately start as a background thread and return control once the server
+        is ready (see .ready below)
+    """
+
+    def __init__(
+        self,
+        dht: Optional[DHT],
+        expert_backends: Dict[str, ExpertBackend],
+        listen_on: Endpoint = "0.0.0.0:*",
+        num_connection_handlers: int = 1,
+        update_period: int = 30,
+        start=False,
+        checkpoint_dir=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
+        if get_port(listen_on) is None:
+            listen_on = replace_port(listen_on, new_port=get_free_port())
+        self.listen_on, self.port = listen_on, get_port(listen_on)
+
+        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
+        if checkpoint_dir is not None:
+            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
+        else:
+            self.checkpoint_saver = None
+        self.runtime = Runtime(self.experts, **kwargs)
+
+        if self.dht and self.experts:
+            self.dht_handler_thread = DHTHandlerThread(
+                experts=self.experts,
+                dht=self.dht,
+                endpoint=self.listen_on,
+                update_period=self.update_period,
+                daemon=True,
+            )
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    @classmethod
+    def create(
+        cls,
+        listen_on="0.0.0.0:*",
+        num_experts: int = None,
+        expert_uids: str = None,
+        expert_pattern: str = None,
+        expert_cls="ffn",
+        hidden_dim=1024,
+        optim_cls=torch.optim.Adam,
+        scheduler: str = "none",
+        num_warmup_steps=None,
+        num_total_steps=None,
+        clip_grad_norm=None,
+        num_handlers=None,
+        min_batch_size=1,
+        max_batch_size=4096,
+        device=None,
+        no_dht=False,
+        initial_peers=(),
+        checkpoint_dir: Optional[Path] = None,
+        compression=CompressionType.NONE,
+        stats_report_interval: Optional[int] = None,
+        custom_module_path=None,
+        *,
+        start: bool,
+    ) -> Server:
+        """
+        Instantiate a server with several identical experts. See argparse comments below for details
+        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
+        :param num_experts: run this many identical experts
+        :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
+           means "sample random experts between myprefix.0.0 and myprefix.31.255";
+        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
+        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
+        :param hidden_dim: main dimension for expert_cls
+        :param num_handlers: server will use this many parallel processes to handle incoming requests
+        :param min_batch_size: the total number of examples in the same batch will be at least this value
+        :param max_batch_size: the total number of examples in the same batch will not exceed this value
+        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+
+        :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_total_steps: the total number of steps for LR schedule
+        :param clip_grad_norm: maximum gradient norm used for clipping
+
+        :param no_dht: if specified, the server will not be attached to a dht
+        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
+
+        :param checkpoint_dir: directory to save and load expert checkpoints
+
+        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
+            hosted on this server. For a more fine-grained compression, start server in python and specify compression
+            for each BatchTensorProto in ExpertBackend for the respective experts.
+
+        :param start: if True, starts server right away and returns when server is ready for requests
+        :param stats_report_interval: interval between two reports of batch processing performance statistics
+        """
+        if custom_module_path is not None:
+            add_custom_models_from_file(custom_module_path)
+        assert expert_cls in name_to_block
+
+        if no_dht:
+            dht = None
+        else:
+            dht = DHT(initial_peers=initial_peers, start=True)
+            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+
+        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
+            num_experts is not None and expert_uids is None
+        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
+
+        if expert_uids is None:
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [
+                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
+                ]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
+                    )
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(_generate_uids(uids_to_generate, expert_pattern, dht))
+
+        num_experts = len(expert_uids)
+        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
+        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+        sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim)
+        if isinstance(sample_input, tuple):
+            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
+        else:
+            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
+
+        scheduler = schedule_name_to_scheduler[scheduler]
+
+        # initialize experts
+        experts = {}
+        for expert_uid in expert_uids:
+            expert = name_to_block[expert_cls](hidden_dim)
+            experts[expert_uid] = ExpertBackend(
+                name=expert_uid,
+                expert=expert,
+                args_schema=args_schema,
+                optimizer=optim_cls(expert.parameters()),
+                scheduler=scheduler,
+                num_warmup_steps=num_warmup_steps,
+                num_total_steps=num_total_steps,
+                clip_grad_norm=clip_grad_norm,
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )
+
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)
+
+        return cls(
+            dht,
+            experts,
+            listen_on=listen_on,
+            num_connection_handlers=num_handlers,
+            device=device,
+            checkpoint_dir=checkpoint_dir,
+            stats_report_interval=stats_report_interval,
+            start=start,
+        )
+
+    def run(self):
+        """
+        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        logger.info(f"Server started at {self.listen_on}")
+        logger.info(f"Got {len(self.experts)} experts:")
+        for expert_name, backend in self.experts.items():
+            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
+
+        if self.dht:
+            if not self.dht.is_alive():
+                self.dht.run_in_background(await_ready=True)
+
+            if self.experts:
+                self.dht_handler_thread.start()
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for process in self.conn_handlers:
+            if not process.is_alive():
+                process.start()
+            process.ready.wait()
+
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts Server in a background thread. If await_ready is True, this method will wait until the background
+        server is ready to process incoming requests or until :timeout: seconds have passed.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+
+    @property
+    def ready(self) -> mp.synchronize.Event:
+        """
+        An event (multiprocessing.Event) that is set when the server is ready to process requests.
+
+        Example
+        =======
+        >>> server.start()
+        >>> server.ready.wait(timeout=10)
+        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
+        """
+        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
+
+    def shutdown(self):
+        """
+        Gracefully terminate the server, process-safe.
+        Please note that terminating the server in any other way (e.g. by killing its processes) may leave zombie processes.
+        If you have already caused a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
+        """
+        self.ready.clear()
+
+        for process in self.conn_handlers:
+            process.terminate()
+            process.join()
+        logger.debug("Connection handlers terminated")
+
+        if self.dht and self.experts:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
+
+        if self.dht is not None:
+            self.dht.shutdown()
+            self.dht.join()
+
+        logger.debug(f"Shutting down runtime")
+
+        self.runtime.shutdown()
+        logger.info("Server shutdown succesfully")
+
+
+@contextmanager
+def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, List[Multiaddr]]:
+    """A context manager that creates server in a background process, awaits .ready on entry and shuts down on exit"""
+    pipe, runners_pipe = mp.Pipe(duplex=True)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    try:
+        runner.start()
+        # once the server is ready, runner will send us
+        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
+        start_ok, data = pipe.recv()
+        if start_ok:
+            yield data
+            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
+        else:
+            raise RuntimeError(f"Server failed to start: {data}")
+    finally:
+        runner.join(timeout=shutdown_timeout)
+        if runner.is_alive():
+            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
+            runner.kill()
+            logger.info("Server terminated")
+
+
+def _server_runner(pipe, *args, **kwargs):
+    try:
+        server = Server.create(*args, start=True, **kwargs)
+    except Exception as e:
+        logger.exception(f"Encountered an exception when starting a server: {e}")
+        pipe.send((False, f"{type(e).__name__} {e}"))
+        return
+
+    try:
+        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
+        pipe.send((True, (server.listen_on, dht_maddrs)))
+        pipe.recv()  # wait for shutdown signal
+
+    finally:
+        logger.info("Shutting down server...")
+        server.shutdown()
+        server.join()
+        logger.info("Server shut down")
+
+
+def _generate_uids(
+    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
+) -> List[str]:
+    """
+    Sample experts from a given pattern, remove duplicates.
+    :param num_experts: sample this many unique expert uids
+    :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
+     means "sample random experts between myprefix.0.0 and myprefix.31.255";
+    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
+    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
+    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
+     a small chance of sampling duplicate expert uids.
+    """
+    remaining_attempts = attempts_per_expert * num_experts
+    found_uids, attempted_uids = list(), set()
+
+    def _generate_uid():
+        if expert_pattern is None:
+            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
+
+        uid = []
+        for block in expert_pattern.split(UID_DELIMITER):
+            try:
+                if "[" not in block and "]" not in block:
+                    uid.append(block)
+                elif block.startswith("[") and block.endswith("]") and ":" in block:
+                    slice_start, slice_end = map(int, block[1:-1].split(":"))
+                    uid.append(str(random.randint(slice_start, slice_end - 1)))
+                else:
+                    raise ValueError("Block must be either fixed or a range [from:to]")
+            except KeyboardInterrupt:
+                raise
+            except Exception as e:
+                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
+        return UID_DELIMITER.join(uid)
+
+    while remaining_attempts > 0 and len(found_uids) < num_experts:
+
+        # 1. sample new expert uids at random
+        new_uids = []
+        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
+            new_uid = _generate_uid()
+            remaining_attempts -= 1
+            if new_uid not in attempted_uids:
+                attempted_uids.add(new_uid)
+                new_uids.append(new_uid)
+
+        # 2. look into DHT (if given) and remove duplicates
+        if dht is not None:
+            existing_expert_uids = {
+                found_expert.uid for found_expert in get_experts(dht, new_uids) if found_expert is not None
+            }
+            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
+
+        found_uids += new_uids
+
+    if len(found_uids) != num_experts:
+        logger.warning(
+            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
+            f"{attempts_per_expert * num_experts} attempts"
+        )
+    return found_uids
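
To tie the relocated module together, here is a rough usage sketch; the expert uids, hidden dimension, and input tensor are illustrative, and the call pattern mirrors the tests further below:

    import torch
    from hivemind.dht import DHT
    from hivemind.moe.client import RemoteExpert
    from hivemind.moe.server import background_server

    with background_server(
        expert_uids=[f"expert.{i}" for i in range(4)],
        expert_cls="ffn",
        hidden_dim=16,
        num_handlers=1,
        device="cpu",
    ) as (server_endpoint, dht_maddrs):
        client_dht = DHT(initial_peers=dht_maddrs, start=True)  # join the server's DHT (optional)
        expert = RemoteExpert("expert.0", server_endpoint)      # or query an expert directly by endpoint
        output = expert(torch.randn(2, 16))                     # forward pass through the remote ffn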

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,7 +1,7 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
+from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
 from hivemind.optim.training_averager import TrainingAverager
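
Downstream code should not need to change its imports after this move; a quick sanity-check sketch:

    import hivemind
    from hivemind.optim.optimizer import Optimizer

    assert hivemind.Optimizer is Optimizer  # the class simply moved out of the experimental package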

+ 8 - 0
hivemind/optim/base.py

@@ -1,3 +1,5 @@
+from warnings import warn
+
 import torch
 
 from hivemind.dht import DHT
@@ -8,6 +10,12 @@ class DecentralizedOptimizerBase(torch.optim.Optimizer):
 
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
         self.opt, self.dht = opt, dht
+        warn(
+            "DecentralizedOptimizerBase and its subclasses have been deprecated and will be removed "
+            "in hivemind 1.1.0. Use hivemind.Optimizer instead",
+            FutureWarning,
+            stacklevel=2,
+        )
 
     @property
     def state(self):
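
The warning fires at construction time, so projects that cannot migrate immediately can acknowledge it explicitly; a minimal sketch whose filter text matches the message above:

    import warnings

    warnings.filterwarnings(
        "ignore",
        category=FutureWarning,
        message="DecentralizedOptimizerBase and its subclasses have been deprecated",
    )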

+ 5 - 5
hivemind/optim/collaborative.py

@@ -57,15 +57,15 @@ class TrainingProgressSchema(BaseModel):
 
 class CollaborativeOptimizer(DecentralizedOptimizerBase):
     """
-    :note: **For new projects please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
-      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and a many advanced ones.
-      CollaborativeOptimizer will still be supported for a while, but it will be deprecated eventually.
-
-    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers
+    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.
 
     These optimizers use DHT to track how much progress the collaboration has made towards the target batch size.
     Once enough samples are accumulated, the optimizers compute a weighted average of their statistics.
 
+    :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
+      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and many advanced ones.
+      CollaborativeOptimizer will still be supported for a while, but it will be deprecated in v1.1.0.
+
     :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
 
       * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
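
To make the recommended migration concrete, here is a rough sketch of setting up hivemind.Optimizer in place of CollaborativeOptimizer; the model, DHT, and run settings are hypothetical, and the argument names should be double-checked against the hivemind.Optimizer docstring:

    import torch
    import hivemind

    model = torch.nn.Linear(16, 16)                      # hypothetical model
    dht = hivemind.DHT(start=True)

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="my_run",                                 # peers with the same run_id train together
        params=model.parameters(),
        optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
        target_batch_size=4096,                          # global batch size that triggers a collaborative step
        batch_size_per_step=32,                          # samples processed by each local .step() call
        verbose=True,
    )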

+ 0 - 0
hivemind/optim/experimental/__init__.py


+ 0 - 0
hivemind/optim/experimental/grad_averager.py → hivemind/optim/grad_averager.py


+ 16 - 5
hivemind/optim/grad_scaler.py

@@ -35,6 +35,7 @@ class GradScaler(TorchGradScaler):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
         self._is_ready_to_update = False
+        self._inner_optimizer_states = {}
         self._optimizer_states_to_reset = set()
         self._lock = threading.RLock()
 
@@ -52,7 +53,12 @@ class GradScaler(TorchGradScaler):
             assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
             if self._is_running_global_step:
                 super().unscale_(optimizer)
-                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                # note: we store unscaled optimizer state in a separate dict and not in _per_optimizer_states in order
+                # to avoid an edge case where full DPU peer encounters overflow in local gradients while averaging
+                # offloaded gradients (i.e. after global unscale but before global step). Due to overflow, next call to
+                # .update on user side would reset *all* optimizer states and cause .step to unscale gradients twice.
+                # Offloaded optimizer is not affected by overflow in on-device gradients and should not be reset.
                 return True
             else:
                 self._check_inf_per_device(optimizer)
@@ -62,14 +68,19 @@ class GradScaler(TorchGradScaler):
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
             # ^-- invoked privately within hivemind optimizer
+            inner_optimizer = optimizer
             with self._lock:
                 if self._is_ready_to_update:
                     logger.warning("Please call grad_scaler.update() after each step")
+
+                inner_optimizer_state = self._inner_optimizer_states.pop(id(inner_optimizer), None)
+                if inner_optimizer_state is not None:
+                    self._per_optimizer_states[id(inner_optimizer)] = inner_optimizer_state
                 assert (
-                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
-                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
-                if self.are_grads_finite(optimizer, use_cached=True):
-                    super().step(optimizer, *args, **kwargs)
+                    self._per_optimizer_states[id(inner_optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step"
+                if self.are_grads_finite(inner_optimizer, use_cached=True):
+                    super().step(inner_optimizer, *args, **kwargs)
                 else:
                     logger.warning("Skipping global step due to gradient over/underflow")
                 self._is_ready_to_update = True
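
For context, the scaler is driven the same way as torch.cuda.amp.GradScaler; a minimal sketch of the intended mixed-precision loop, continuing the hypothetical model and opt from the migration sketch above:

    import torch
    from hivemind.optim.grad_scaler import GradScaler

    scaler = GradScaler()
    batches = [(torch.randn(32, 16), torch.randn(32, 16)) for _ in range(10)]  # hypothetical data

    for inputs, targets in batches:
        opt.zero_grad()                                   # opt is the hivemind.Optimizer from the sketch above
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
        scaler.scale(loss).backward()
        scaler.step(opt)     # the global unscale and step happen inside hivemind.Optimizer
        scaler.update()      # as the warning above notes, call update() after every step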

+ 23 - 11
hivemind/optim/experimental/optimizer.py → hivemind/optim/optimizer.py

@@ -11,9 +11,10 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.experimental.grad_averager import GradientAverager
-from hivemind.optim.experimental.progress_tracker import ProgressTracker
-from hivemind.optim.experimental.state_averager import (
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
+from hivemind.optim.state_averager import (
     LRSchedulerBase,
     OptimizerFactory,
     Parameters,
@@ -22,7 +23,6 @@ from hivemind.optim.experimental.state_averager import (
     TorchOptimizer,
     TrainingStateAverager,
 )
-from hivemind.optim.grad_scaler import GradScaler
 from hivemind.utils import PerformanceEMA, get_dht_time, get_logger
 
 logger = get_logger(__name__)
@@ -154,7 +154,7 @@ class Optimizer(torch.optim.Optimizer):
 
     :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
     :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
-    :param performance_ema_alpha: moving average alpha  in ProgressTracer, TrainingStateAverager and Optimizer
+    :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
     :param verbose: if True, report internal events such as accumulating gradients and running background tasks
 
     :note: in a large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer
@@ -197,6 +197,8 @@ class Optimizer(torch.optim.Optimizer):
         shutdown_timeout: float = 5,
         verbose: bool = False,
     ):
+        self._parent_pid = os.getpid()
+
         client_mode = client_mode if client_mode is None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
@@ -262,7 +264,6 @@ class Optimizer(torch.optim.Optimizer):
 
         self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
         self._schema_hash = self._compute_schema_hash()
-        self._parent_pid = os.getpid()
 
         self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
         # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
@@ -345,6 +346,10 @@ class Optimizer(torch.optim.Optimizer):
         """
         return self.state_averager.local_epoch
 
+    @property
+    def local_progress(self) -> LocalTrainingProgress:
+        return self.tracker.local_progress
+
     @property
     def use_local_updates(self) -> bool:
         return self.grad_averager is None
@@ -384,7 +389,7 @@ class Optimizer(torch.optim.Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        if not self.auxiliary and self.should_load_state_from_peers():
+        if not self.auxiliary and self._should_load_state_from_peers():
             logger.log(self.status_loglevel, "Peer is out of sync")
             self.load_state_from_peers()
             return loss  # local gradients were computed with out-of-sync parameters, must start over
@@ -519,8 +524,12 @@ class Optimizer(torch.optim.Optimizer):
                 logger.exception(e)
 
         if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
-            logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
-            self._tag_along_with_zero_weight(self.scheduled_grads)
+            if self.tracker.global_progress.num_peers > 1:
+                logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+                self._tag_along_with_zero_weight(self.scheduled_grads)
+            else:
+                logger.log(self.status_loglevel, f"Skipping pre-scheduled averaging round: there are no other peers")
+                self.scheduled_grads.cancel()
             self.scheduled_grads = None
         return began_averaging_gradients
 
@@ -564,7 +573,6 @@ class Optimizer(torch.optim.Optimizer):
 
         if eta_seconds_to_averaging <= self.matchmaking_time:
             if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
-
                 min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
                 actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
                 logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f} sec")
@@ -626,7 +634,7 @@ class Optimizer(torch.optim.Optimizer):
                 else:
                     param.grad.zero_()
 
-    def should_load_state_from_peers(self) -> bool:
+    def _should_load_state_from_peers(self) -> bool:
         """
         If true, peer will discard local progress and attempt to download state from peers.
         This method allows peer to continue training in two cases:
@@ -646,6 +654,10 @@ class Optimizer(torch.optim.Optimizer):
             return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
         return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
 
+    def is_synchronized_with_peers(self) -> bool:
+        """Checks whether the current peer is up-to-date with others in terms of the epoch (step) number."""
+        return self.local_epoch >= self.tracker.global_epoch - 1
+
     def load_state_from_peers(self, **kwargs):
         """
         Attempt to load the newest collaboration state from other peers within the same run_id.
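
The additions above (local_progress and is_synchronized_with_peers) are meant for user-side monitoring; a minimal sketch, again with opt as the hivemind.Optimizer from the earlier sketch:

    # catch up with the collaboration before evaluating or saving checkpoints
    if not opt.is_synchronized_with_peers():
        opt.load_state_from_peers()

    progress = opt.local_progress
    print(f"epoch {opt.local_epoch}: {progress.samples_accumulated} samples accumulated locally")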

+ 17 - 12
hivemind/optim/experimental/progress_tracker.py → hivemind/optim/progress_tracker.py

@@ -195,6 +195,7 @@ class ProgressTracker(threading.Thread):
     async def _progress_reporter(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
         last_report_time = -float("inf")
+        last_report_epoch = -float("inf")
         store_task = None
         try:
             while not self.shutdown_triggered.is_set():
@@ -209,19 +210,23 @@ class ProgressTracker(threading.Thread):
 
                 local_progress = self.local_progress
                 last_report_time = get_dht_time()
-
-                store_task = asyncio.create_task(
-                    asyncio.wait_for(
-                        self.dht.store(
-                            key=self.training_progress_key,
-                            subkey=self._local_public_key,
-                            value=local_progress.dict(),
-                            expiration_time=last_report_time + self.metadata_expiration,
-                            return_future=True,
-                        ),
-                        timeout=self.metadata_expiration,
+                if local_progress.samples_accumulated > 0:
+                    last_report_epoch = self.global_epoch
+
+                if last_report_epoch >= self.global_epoch - 1:
+                    # report progress if peer is synchronized and actively reporting samples. Do not report aux peers.
+                    store_task = asyncio.create_task(
+                        asyncio.wait_for(
+                            self.dht.store(
+                                key=self.training_progress_key,
+                                subkey=self._local_public_key,
+                                value=local_progress.dict(),
+                                expiration_time=last_report_time + self.metadata_expiration,
+                                return_future=True,
+                            ),
+                            timeout=self.metadata_expiration,
+                        )
                     )
-                )
         finally:
             logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}")
             if store_task is not None:

+ 9 - 0
hivemind/optim/experimental/state_averager.py → hivemind/optim/state_averager.py

@@ -152,6 +152,15 @@ class TrainingStateAverager(DecentralizedAverager):
         parameter_names = tuple(nested_flatten(parameter_names))
         assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
+        params_with_grad = sum(p.numel() for p in parameters if p.requires_grad)
+        params_no_grad = sum(p.numel() for p in parameters if not p.requires_grad)
+        if params_no_grad >= params_with_grad:
+            logger.warning(
+                "The majority of parameters have requires_grad=False, but they are still synchronized"
+                " with peers. If these parameters are frozen (not updated), please do not feed them into "
+                "the optimizer at all in order to avoid communication overhead. Proceeding anyway."
+            )
+
         return param_groups, parameters, parameter_names
 
     def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
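
In practice, the warning suggests filtering frozen parameters out before they reach the (inner) optimizer; a minimal sketch, reusing the hypothetical model from the sketches above:

    # only trainable parameters reach the optimizer, so frozen ones are never synchronized with peers
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    inner_optimizer = torch.optim.Adam(trainable_params, lr=1e-3)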

+ 5 - 0
hivemind/p2p/p2p_daemon.py

@@ -140,6 +140,11 @@ class P2P:
         socket_uid = secrets.token_urlsafe(8)
         self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pd-{socket_uid}.sock")
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
+        if announce_maddrs is not None:
+            for addr in announce_maddrs:
+                addr = Multiaddr(addr)
+                if ("tcp" in addr and addr["tcp"] == "0") or ("udp" in addr and addr["udp"] == "0"):
+                    raise ValueError("Please specify an explicit port in announce_maddrs: port 0 is not supported")
 
         need_bootstrap = bool(initial_peers) or use_ipfs
         process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})
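
A short sketch of what the new validation accepts and rejects; the addresses are hypothetical, and announce_maddrs/host_maddrs are forwarded by hivemind.DHT to the underlying P2P daemon:

    import hivemind

    # accepted: every announced address carries an explicit port
    dht = hivemind.DHT(
        host_maddrs=["/ip4/0.0.0.0/tcp/31337"],
        announce_maddrs=["/ip4/203.0.113.5/tcp/31337"],
        start=True,
    )

    # rejected: port 0 ("pick any free port") cannot be announced to other peers
    # hivemind.DHT(announce_maddrs=["/ip4/203.0.113.5/tcp/0"], start=True)  raises ValueError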

+ 4 - 4
hivemind/utils/tensor_descr.py

@@ -46,11 +46,11 @@ class TensorDescriptor(DescriptorBase):
             tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
         )
 
-    def make_empty(self, **kwargs):
+    def make_zeros(self, **kwargs):
         properties = asdict(self)
         properties.update(kwargs)
         properties.pop("compression")
-        return torch.empty(**properties)
+        return torch.zeros(**properties)
 
 
 def _str_to_torch_type(name: str, torch_type: type):
@@ -86,9 +86,9 @@ class BatchTensorDescriptor(TensorDescriptor):
             compression=compression if tensor.is_floating_point() else CompressionType.NONE,
         )
 
-    def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
+    def make_zeros(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
-        return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
+        return super().make_zeros(size=(*batch_size, *self.shape[1:]), **kwargs)
 
     def packb(self) -> bytes:
         obj_dict = asdict(self)
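
A minimal sketch of the renamed descriptor API; the tensor shape is illustrative:

    import torch
    from hivemind.utils.tensor_descr import BatchTensorDescriptor

    descr = BatchTensorDescriptor.from_tensor(torch.randn(2, 16))
    dummy = descr.make_zeros(4)              # the batch dimension is supplied at call time
    assert dummy.shape == (4, 16) and dummy.abs().sum().item() == 0.0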

+ 31 - 34
tests/test_moe.py

@@ -3,9 +3,12 @@ import numpy as np
 import pytest
 import torch
 
-import hivemind
-from hivemind.moe.client.expert import DUMMY
-from hivemind.moe.server import background_server, declare_experts, layers
+from hivemind.dht import DHT
+from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
+from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
+from hivemind.moe.server.layers import name_to_block
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
 @pytest.mark.forked
@@ -16,11 +19,9 @@ def test_moe():
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
     ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn."
-        )
+        dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
         for i in range(3):
             out = dmoe(torch.randn(10, 16))
@@ -35,9 +36,9 @@ def test_no_experts():
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
     ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
-        dmoe = hivemind.RemoteSwitchMixtureOfExperts(
+        dmoe = RemoteSwitchMixtureOfExperts(
             in_features=16,
             grid_size=(4, 4, 4),
             dht=dht,
@@ -74,10 +75,10 @@ def test_call_many(hidden_dim=16):
     ) as (server_endpoint, _):
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = hivemind.RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
+        e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
+        e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
 
-        mask, expert_outputs = hivemind.moe.client.moe._RemoteCallMany.apply(
+        mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
             [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
             k_min,
@@ -130,8 +131,8 @@ def test_remote_module_call(hidden_dim=16):
         optim_cls=None,
         no_dht=True,
     ) as (server_endpoint, _):
-        real_expert = hivemind.RemoteExpert("expert.0", server_endpoint)
-        fake_expert = hivemind.RemoteExpert("oiasfjiasjf", server_endpoint)
+        real_expert = RemoteExpert("expert.0", server_endpoint)
+        fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
 
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
@@ -152,12 +153,10 @@ def test_remote_module_call(hidden_dim=16):
 @pytest.mark.forked
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
-    dht = hivemind.DHT(start=True)
+    dht = DHT(start=True)
     assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
 
-    dmoe = hivemind.RemoteMixtureOfExperts(
-        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn."
-    )
+    dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
 
     for i in range(25):
         input = torch.randn(32)
@@ -174,7 +173,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[hivemind.RemoteExpert(uid, "") for uid in all_expert_uids]],
+            [[RemoteExpert(uid, "") for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
@@ -197,7 +196,7 @@ def test_determinism(hidden_dim=16):
         optim_cls=None,
         no_dht=True,
     ) as (server_endpoint, _):
-        expert = hivemind.RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
+        expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -212,8 +211,8 @@ def test_determinism(hidden_dim=16):
 @pytest.mark.forked
 def test_compute_expert_scores():
     try:
-        dht = hivemind.DHT(start=True)
-        moe = hivemind.moe.RemoteMixtureOfExperts(
+        dht = DHT(start=True)
+        moe = RemoteMixtureOfExperts(
             dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
         )
         gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
@@ -221,13 +220,11 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                hivemind.RemoteExpert(
-                    uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337"
-                )
+                RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))
-        ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
        ]  # note: these experts do not exist on the server; we use them only to test compute_expert_scores
         logits = moe.compute_expert_scores([gx, gy], batch_experts)
         torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
         assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
@@ -247,25 +244,25 @@ def test_client_anomaly_detection():
 
     experts = {}
     for i in range(4):
-        expert = layers.name_to_block["ffn"](HID_DIM)
-        experts[f"expert.{i}"] = hivemind.ExpertBackend(
+        expert = name_to_block["ffn"](HID_DIM)
+        experts[f"expert.{i}"] = ExpertBackend(
             name=f"expert.{i}",
             expert=expert,
             optimizer=torch.optim.Adam(expert.parameters()),
-            args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
-            outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
+            args_schema=(BatchTensorDescriptor(HID_DIM),),
+            outputs_schema=BatchTensorDescriptor(HID_DIM),
             max_batch_size=16,
         )
 
     experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
 
-    dht = hivemind.DHT(start=True)
-    server = hivemind.moe.Server(dht, experts, num_connection_handlers=1)
+    dht = DHT(start=True)
+    server = Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:
         server.ready.wait()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
+        dmoe = RemoteMixtureOfExperts(
             in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
 
@@ -282,7 +279,7 @@ def test_client_anomaly_detection():
         with pytest.raises(ValueError):
             inf_loss.backward()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
+        dmoe = RemoteMixtureOfExperts(
             in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         output = dmoe(input)

+ 4 - 4
tests/test_optimizer.py

@@ -11,10 +11,10 @@ import torch.nn.functional as F
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.experimental.grad_averager import GradientAverager
-from hivemind.optim.experimental.optimizer import Optimizer
-from hivemind.optim.experimental.progress_tracker import ProgressTracker
-from hivemind.optim.experimental.state_averager import TrainingStateAverager
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.optimizer import Optimizer
+from hivemind.optim.progress_tracker import ProgressTracker
+from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey