
Support learning rate schedulers in ExpertBackend (#196)

* Add empty __init__ to hivemind_cli for correct package discovery

* Support learning rate schedulers in ExpertBackend

* Save/load full expert state

* Don't pass compression to make_empty

* Use fork instead of spawn for server subprocesses

* Remove load_expert_states

* Make TaskPoolBase an abstract class

* Output warning if some of the keys in state_dict are missing

Co-authored-by: justheuristic <justheuristic@gmail.com>
Max Ryabinin, 4 years ago
commit 3024d381c5

+ 6 - 6
.circleci/config.yml

@@ -8,11 +8,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:
@@ -28,11 +28,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:
@@ -48,11 +48,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:

+ 0 - 0
hivemind/hivemind_cli/__init__.py


+ 8 - 1
hivemind/hivemind_cli/run_server.py

@@ -8,6 +8,7 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.server import Server
 from hivemind.utils.threading import increase_file_limit
 from hivemind.utils.logging import get_logger
+from hivemind.server.layers import schedule_name_to_scheduler
 
 logger = get_logger(__name__)
 
@@ -28,13 +29,20 @@ def main():
     parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
                         help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
+
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
                         help='The total number of examples in the same batch will not exceed this value')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
+
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
+    parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',
+                        help='LR scheduler type to use')
+    parser.add_argument('--num-warmup-steps', type=int, required=False, help='the number of warmup steps for LR schedule')
+    parser.add_argument('--num-training-steps', type=int, required=False, help='the total number of steps for LR schedule')
+
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='one or more peers that can welcome you to the dht, e.g. 1.2.3.4:1337 192.132.231.4:4321')
@@ -45,7 +53,6 @@ def main():
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
                         'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
     parser.add_argument('--checkpoint_dir', type=Path, required=False, help='Directory to store expert checkpoints')
-    parser.add_argument('--load_experts', action='store_true', help='Load experts from the checkpoint directory')
 
     # fmt:on
     args = vars(parser.parse_args())

+ 48 - 94
hivemind/server/__init__.py

@@ -2,23 +2,22 @@ from __future__ import annotations
 
 import multiprocessing as mp
 import multiprocessing.synchronize
-import random
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, Optional, Tuple, List
+from typing import Dict, Optional, Tuple
 from pathlib import Path
 
 import torch
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.expert_uid import UID_DELIMITER
-from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
+from hivemind.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.server.checkpoints import CheckpointSaver, load_experts, is_directory
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.server.layers import name_to_block, name_to_input
+from hivemind.server.layers import name_to_block, name_to_input, schedule_name_to_scheduler
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
@@ -68,11 +67,12 @@ class Server(threading.Thread):
         if start:
             self.run_in_background(await_ready=True)
 
-    @staticmethod
-    def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
-               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+    @classmethod
+    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
+               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
+               num_warmup_steps=None, num_training_steps=None, num_handlers=None, max_batch_size=4096, device=None,
+               no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,7 +85,12 @@ class Server(threading.Thread):
         :param num_handlers: server will use this many parallel processes to handle incoming requests
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+
         :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_training_steps: the total number of steps for LR schedule
+
         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
            e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers
@@ -93,8 +98,7 @@ class Server(threading.Thread):
         :param dht_port:  DHT node will listen on this port, default = find open port
            You can then use this node as initial peer for subsequent servers.
 
-        :param checkpoint_dir: directory to save expert checkpoints
-        :param load_experts: whether to load expert checkpoints from checkpoint_dir
+        :param checkpoint_dir: directory to save and load expert checkpoints
 
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
@@ -113,23 +117,29 @@ class Server(threading.Thread):
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
             logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
-        if load_experts:
-            assert dir_is_correct(checkpoint_dir)
-            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
-            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
-            if expert_uids:
-                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
-            else:
-                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")
-
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
+        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
+                (num_experts is not None and expert_uids is None)), \
+            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
 
-        # get expert uids if not loaded previously
         if expert_uids is None:
-            assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
-            logger.info(f"Generating expert uids from pattern {expert_pattern}")
-            expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
+                               (child / 'checkpoint_last.pt').exists()]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
@@ -142,6 +152,8 @@ class Server(threading.Thread):
         else:
             args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
+        scheduler = schedule_name_to_scheduler[scheduler]
+
         # initialize experts
         experts = {}
         for expert_uid in expert_uids:
@@ -150,15 +162,17 @@ class Server(threading.Thread):
                                                          args_schema=args_schema,
                                                          outputs_schema=hivemind.BatchTensorDescriptor(
                                                              hidden_dim, compression=compression),
-                                                         opt=optim_cls(expert.parameters()),
+                                                         optimizer=optim_cls(expert.parameters()),
+                                                         scheduler=scheduler,
+                                                         num_warmup_steps=num_warmup_steps,
+                                                         num_training_steps=num_training_steps,
                                                          max_batch_size=max_batch_size)
 
-        if load_experts:
-            load_weights(experts, checkpoint_dir)
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)
 
-        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
-                        start=start)
-        return server
+        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
+                   checkpoint_dir=checkpoint_dir, start=start)
 
     def run(self):
         """
@@ -241,7 +255,7 @@ class Server(threading.Thread):
 def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.get_context("spawn").Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
 
     try:
         runner.start()
@@ -269,63 +283,3 @@ def _server_runner(pipe, *args, **kwargs):
         server.shutdown()
         server.join()
         logger.info("Server shut down.")
-
-
-def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
-                               attempts_per_expert=10) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if '[' not in block and ']' not in block:
-                    uid.append(block)
-                elif block.startswith('[') and block.endswith(']') and ':' in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(':'))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt as e:
-                raise e
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
-                                    if found_expert is not None}
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-                       f"{attempts_per_expert * num_experts} attempts")
-    return found_uids
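
For context, the reworked Server.create now takes the scheduler by name together with the warmup/total step counts, and resumes any experts that already have checkpoints in checkpoint_dir. A minimal sketch of the new call; the pattern, dimensions and step counts below are illustrative only:

import torch
from hivemind.server import Server

# roughly what run_server.py does for --scheduler linear --num-warmup-steps 1000 --num-training-steps 100000
server = Server.create(num_experts=2, expert_pattern='ffn.[0:32].[0:32]', expert_cls='ffn', hidden_dim=64,
                       optim_cls=torch.optim.Adam, scheduler='linear',
                       num_warmup_steps=1000, num_training_steps=100000,
                       checkpoint_dir=None, no_dht=True, start=True)
server.shutdown()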

+ 15 - 8
hivemind/server/checkpoints.py

@@ -1,17 +1,20 @@
+import os
 import threading
 from datetime import datetime
 from pathlib import Path
 from shutil import copy2
 from tempfile import TemporaryDirectory
 from typing import Dict
-import os
 
 import torch
 
 from hivemind.server.expert_backend import ExpertBackend
+from hivemind.utils.logging import get_logger
 
+logger = get_logger(__name__)
 
-def dir_is_correct(directory: Path):
+
+def is_directory(directory: Path):
     assert directory is not None
     assert directory.exists()
     assert directory.is_dir()
@@ -33,7 +36,7 @@ def copy_tree(src: str, dst: str):
 class CheckpointSaver(threading.Thread):
     def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
         super().__init__()
-        assert dir_is_correct(checkpoint_dir)
+        assert is_directory(checkpoint_dir)
         self.expert_backends = expert_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
@@ -48,21 +51,25 @@ class CheckpointSaver(threading.Thread):
 
 
 def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
-    assert dir_is_correct(checkpoint_dir)
+    logger.debug(f'Storing experts at {checkpoint_dir.absolute()}')
+    assert is_directory(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep='_')
     with TemporaryDirectory() as tmpdirname:
         for expert_name, expert_backend in experts.items():
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
             checkpoint_name = expert_dir / f'checkpoint_{timestamp}.pt'
-            torch.save(expert_backend.state_dict(), checkpoint_name)
+            torch.save(expert_backend.get_full_state(), checkpoint_name)
             os.symlink(checkpoint_name, expert_dir / 'checkpoint_last.pt')
         copy_tree(tmpdirname, str(checkpoint_dir))
 
 
-def load_weights(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
-    assert dir_is_correct(checkpoint_dir)
+def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+    assert is_directory(checkpoint_dir)
     for expert_name, expert in experts.items():
         checkpoints_folder = checkpoint_dir / expert_name
         latest_checkpoint = checkpoints_folder / 'checkpoint_last.pt'
-        expert.load_state_dict(torch.load(latest_checkpoint))
+        if latest_checkpoint.exists():
+            expert.load_full_state(torch.load(latest_checkpoint))
+        else:
+            logger.warning(f'Failed to load checkpoint for expert {expert_name}')
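
A short sketch of how the renamed helpers round-trip the full expert state; the expert name, sizes and learning rate are illustrative:

from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from torch.nn import Linear

from hivemind import BatchTensorDescriptor, ExpertBackend
from hivemind.server.checkpoints import store_experts, load_experts

expert = Linear(16, 16)
backend = ExpertBackend(name='demo_expert', expert=expert,
                        optimizer=torch.optim.SGD(expert.parameters(), lr=0.05),
                        args_schema=(BatchTensorDescriptor(16),),
                        outputs_schema=BatchTensorDescriptor(16), max_batch_size=16)

with TemporaryDirectory() as tmpdir:
    # writes <tmpdir>/demo_expert/checkpoint_<timestamp>.pt plus a checkpoint_last.pt symlink
    store_experts({'demo_expert': backend}, Path(tmpdir))
    # restores model, optimizer, scheduler (if any) and batch statistics; warns about missing keys
    load_experts({'demo_expert': backend}, Path(tmpdir))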

+ 1 - 1
hivemind/server/connection_handler.py

@@ -16,7 +16,7 @@ from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 logger = get_logger(__name__)
 
 
-class ConnectionHandler(mp.Process):
+class ConnectionHandler(mp.context.ForkProcess):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
 

+ 74 - 12
hivemind/server/expert_backend.py

@@ -1,14 +1,17 @@
-from typing import Dict, Sequence, Any, Tuple, Union
+from typing import Dict, Sequence, Any, Tuple, Union, Callable
 
 import torch
 from torch import nn
 
 from hivemind.server.task_pool import TaskPool
-from hivemind.utils import nested_flatten, nested_pack, nested_compare, nested_map, \
-    BatchTensorDescriptor, DUMMY_BATCH_SIZE
+from hivemind.utils import BatchTensorDescriptor, DUMMY_BATCH_SIZE
+from hivemind.utils.logging import get_logger
+from hivemind.utils.nested import nested_flatten, nested_pack, nested_compare, nested_map
 
+logger = get_logger(__name__)
 
-class ExpertBackend(nn.Module):
+
+class ExpertBackend:
     """
     ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
     By default, ExpertBackend handles three types of requests:
@@ -26,20 +29,31 @@ class ExpertBackend(nn.Module):
         you should explicitly register these random variables as model inputs or outputs.
         See hivemind.utils.custom_layers.DeterministicDropout for an example
 
-    :param opt: torch optimizer to be applied on every backward call
+    :param optimizer: torch optimizer to be applied on every backward call
+    :param scheduler: a function to create the learning rate scheduler for the expert
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
     :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
     :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
+    :param num_warmup_steps: the number of warmup steps for LR schedule
+    :param num_training_steps: the total number of steps for LR schedule
     :param kwargs: extra parameters to be forwarded into TaskPool.__init__
     """
 
-    def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
+    def __init__(self, name: str, expert: nn.Module, optimizer: torch.optim.Optimizer, *,
+                 scheduler: Callable = None,
                  args_schema: Tuple[BatchTensorDescriptor, ...] = None,
                  kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
                  outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
+                 num_warmup_steps: int = None, num_training_steps: int = None,
                  **kwargs):
         super().__init__()
-        self.expert, self.opt, self.name = expert, opt, name
+        self.expert, self.optimizer, self.name = expert, optimizer, name
+
+        if scheduler is None:
+            self.scheduler = None
+        else:
+            assert optimizer is not None and num_warmup_steps is not None and num_training_steps is not None
+            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_training_steps)
 
         self.args_schema = args_schema = tuple(args_schema or ())
         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
@@ -61,7 +75,8 @@ class ExpertBackend(nn.Module):
         self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
         self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
 
-        self.register_buffer('update_count', torch.zeros(1, dtype=torch.long))
+        self.update_count = 0
+        self.examples_processed = 0
 
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
@@ -111,6 +126,8 @@ class ExpertBackend(nn.Module):
                                   if tensor.is_floating_point() else tensor.detach())
                       for input_key, tensor in kwargs.items()}
 
+            batch_size = args[0].size(0)
+
             outputs = self.expert(*args, **kwargs)
             assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
 
@@ -121,18 +138,63 @@ class ExpertBackend(nn.Module):
                 nested_flatten(grad_outputs), outputs_flat))
             torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat,
                                     create_graph=False, retain_graph=False)
-            self.apply_gradients()
+            self.apply_gradients(batch_size)
 
         return tuple(x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
                      for x in nested_flatten((args, kwargs)))
 
-    def apply_gradients(self) -> None:
+    def apply_gradients(self, batch_size) -> None:
         """
         Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
         """
-        self.opt.step()
-        self.opt.zero_grad()
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+
+        if self.scheduler is not None:
+            self.scheduler.step()
+
         self.update_count += 1
+        self.examples_processed += batch_size
+
+    def get_stats(self) -> Dict:
+        """
+        Return current expert training statistics (number of updates, number of processed examples after last optimizer step)
+        """
+        return {
+            'updates': self.update_count,
+            'examples_processed': self.examples_processed
+        }
+
+    def get_full_state(self) -> Dict:
+        """
+        Return the current state of the expert (including batch processing statistics)
+        """
+        full_state = {
+            'stats': self.get_stats(),
+            'model': self.expert.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'scheduler': {} if self.scheduler is None else self.scheduler.state_dict()
+        }
+        return full_state
+
+    def load_full_state(self, state_dict: Dict):
+        if 'stats' in state_dict:
+            self.update_count = state_dict['stats']['updates']
+            self.examples_processed = state_dict['stats']['examples_processed']
+        else:
+            logger.warning(f'Batch processing stats missing for expert {self.name}')
+
+        self.expert.load_state_dict(state_dict['model'])
+
+        if 'optimizer' in state_dict:
+            self.optimizer.load_state_dict(state_dict['optimizer'])
+        else:
+            logger.warning(f'Optimizer state missing for expert {self.name}')
+
+        if self.scheduler is not None and 'scheduler' in state_dict:
+            self.scheduler.load_state_dict(state_dict['scheduler'])
+        else:
+            logger.warning(f'Learning rate scheduler state missing for expert {self.name}')
 
     def get_info(self) -> Dict[str, Any]:
         """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """

+ 63 - 2
hivemind/server/expert_uid.py

@@ -1,8 +1,11 @@
+import random
 import re
-from typing import NamedTuple, Union, Tuple
+from typing import NamedTuple, Union, Tuple, Optional, List
 
-from hivemind.utils.networking import Endpoint
+from hivemind.dht import DHT
+from hivemind.utils import Endpoint, get_logger
 
+logger = get_logger(__name__)
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
 UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
@@ -30,3 +33,61 @@ def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPref
     return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
 
 
+def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
+                               attempts_per_expert=10) -> List[str]:
+    """
+    Sample experts from a given pattern, remove duplicates.
+    :param num_experts: sample this many unique expert uids
+    :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
+     means "sample random experts between myprefix.0.0 and myprefix.255.255;
+    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
+    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
+    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
+     a small chance of sampling duplicate expert uids.
+    """
+    remaining_attempts = attempts_per_expert * num_experts
+    found_uids, attempted_uids = list(), set()
+
+    def _generate_uid():
+        if expert_pattern is None:
+            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
+
+        uid = []
+        for block in expert_pattern.split(UID_DELIMITER):
+            try:
+                if '[' not in block and ']' not in block:
+                    uid.append(block)
+                elif block.startswith('[') and block.endswith(']') and ':' in block:
+                    slice_start, slice_end = map(int, block[1:-1].split(':'))
+                    uid.append(str(random.randint(slice_start, slice_end - 1)))
+                else:
+                    raise ValueError("Block must be either fixed or a range [from:to]")
+            except KeyboardInterrupt as e:
+                raise e
+            except Exception as e:
+                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
+        return UID_DELIMITER.join(uid)
+
+    while remaining_attempts > 0 and len(found_uids) < num_experts:
+
+        # 1. sample new expert uids at random
+        new_uids = []
+        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
+            new_uid = _generate_uid()
+            remaining_attempts -= 1
+            if new_uid not in attempted_uids:
+                attempted_uids.add(new_uid)
+                new_uids.append(new_uid)
+
+        # 2. look into DHT (if given) and remove duplicates
+        if dht:
+            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
+                                    if found_expert is not None}
+            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
+
+        found_uids += new_uids
+
+    if len(found_uids) != num_experts:
+        logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
+                       f"{attempts_per_expert * num_experts} attempts")
+    return found_uids
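
A quick illustration of the pattern syntax described in the docstring above: each [from:to] block is sampled with randint(from, to - 1), so ranges are exclusive on the right. The pattern and printed sample below are illustrative:

from hivemind.server.expert_uid import generate_uids_from_pattern

uids = generate_uids_from_pattern(num_experts=3, expert_pattern='ffn.[0:4].[0:4]')
print(uids)  # e.g. ['ffn.2.1', 'ffn.0.3', 'ffn.3.0'] -- unique uids sampled from the 4x4 grid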

+ 4 - 1
hivemind/server/layers/__init__.py

@@ -2,6 +2,7 @@ import torch
 
 from hivemind.server.layers.common import FeedforwardBlock, TransformerEncoderLayer, NopExpert
 from hivemind.server.layers.dropout import DeterministicDropout, DeterministicDropoutNetwork
+from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
 name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
                  'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, dim_feedforward=4 * hid_dim, nhead=16),
@@ -10,7 +11,9 @@ name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
 
 name_to_input = {'ffn': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
                  'transformer': lambda batch_size, hid_dim:
-                 (torch.empty((batch_size, 128, hid_dim)), torch.empty((batch_size, hid_dim), dtype=torch.bool)),
+                 (torch.empty((batch_size, 128, hid_dim)), torch.empty((batch_size, 128), dtype=torch.bool)),
                  'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
                  'det_dropout': lambda batch_size, hid_dim:
                  (torch.empty((batch_size, hid_dim)), torch.randint(0, 1, (batch_size, hid_dim)))}
+
+schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}

+ 27 - 0
hivemind/server/layers/lr_schedule.py

@@ -0,0 +1,27 @@
+from torch.optim.lr_scheduler import LambdaLR
+
+
+# https://github.com/huggingface/transformers/blob/master/src/transformers/optimization.py
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
+    """
+    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+    Args:
+        optimizer (:class:`~torch.optim.Optimizer`):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (:obj:`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (:obj:`int`):
+            The total number of training steps.
+    Return:
+        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        return max(
+            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+        )
+
+    return LambdaLR(optimizer, lr_lambda)
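
A worked example of the schedule above with tiny step counts, assuming a base learning rate of 1.0: the LR climbs linearly from 0 to the base value over the warmup steps, then decays linearly back to 0 by the final training step.

import torch
from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=2, num_training_steps=4)

lrs = []
for _ in range(5):
    lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()
    scheduler.step()
print(lrs)  # [0.0, 0.5, 1.0, 0.5, 0.0]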

+ 1 - 1
hivemind/server/runtime.py

@@ -59,7 +59,7 @@ class Runtime(threading.Thread):
                 pool.start()
         if self.device is not None:
             for expert_backend in self.expert_backends.values():
-                expert_backend.to(self.device)
+                expert_backend.expert.to(self.device)
 
         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:

+ 10 - 6
hivemind/server/task_pool.py

@@ -3,11 +3,11 @@ Task pool is responsible for receiving tasks and grouping them together for proc
 """
 import ctypes
 import multiprocessing as mp
-import multiprocessing.context
 import os
 import threading
 import time
 import uuid
+from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from concurrent.futures import Future
 from queue import Empty
@@ -21,7 +21,7 @@ logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
 
 
-class TaskPoolBase(mp.context.ForkProcess):
+class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
     """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """
 
     def __init__(self, process_func: callable, daemon=True):
@@ -29,14 +29,17 @@ class TaskPoolBase(mp.context.ForkProcess):
         self.process_func = process_func
         self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
 
+    @abstractmethod
     def run(self):
-        raise NotImplementedError()
+        pass
 
+    @abstractmethod
     def submit_task(self, *args: torch.Tensor) -> Future:
-        raise NotImplementedError()
+        pass
 
+    @abstractmethod
     def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
-        raise NotImplementedError()
+        pass
 
     @property
     def priority(self):
@@ -47,8 +50,9 @@ class TaskPoolBase(mp.context.ForkProcess):
         self._priority.value = float(value)
 
     @property
+    @abstractmethod
     def empty(self):
-        raise NotImplementedError()
+        pass
 
 
 class TaskPool(TaskPoolBase):
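
With the base class abstract, a pool that misses one of these methods now fails at instantiation time rather than raising NotImplementedError on first use. A hypothetical illustration (DummyPool exists only for this sketch):

from hivemind.server.task_pool import TaskPoolBase

class DummyPool(TaskPoolBase):
    def run(self):
        pass
    # submit_task, iterate_minibatches and empty are left unimplemented

DummyPool(process_func=lambda *args: args)  # TypeError: Can't instantiate abstract class DummyPool ...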

+ 1 - 2
hivemind/utils/serializer.py

@@ -1,9 +1,8 @@
 """ A unified interface for several common serialization methods """
-from io import BytesIO
 from typing import Dict, Any
 
-import torch
 import msgpack
+
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 6 - 2
hivemind/utils/tensor_descr.py

@@ -8,6 +8,8 @@ from hivemind.proto.runtime_pb2 import CompressionType
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
 warnings.filterwarnings("ignore", "CUDA initialization*", category=UserWarning)
+
+
 # ^-- cures https://github.com/pytorch/pytorch/issues/47038
 
 
@@ -32,11 +34,13 @@ class TensorDescriptor(DescriptorBase):
 
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor):
-        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, safe_check_pinned(tensor))
+        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad,
+                   safe_check_pinned(tensor))
 
     def make_empty(self, **kwargs):
         properties = asdict(self)
         properties.update(kwargs)
+        properties.pop('compression')
         return torch.empty(**properties)
 
 
@@ -60,7 +64,7 @@ class BatchTensorDescriptor(TensorDescriptor):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs)
 
-    
+
 def safe_check_pinned(tensor: torch.Tensor) -> bool:
     """ Check whether or not a tensor is pinned. If torch cannot initialize cuda, returns False instead of error. """
     try:
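
The make_empty change corresponds to the "Don't pass compression to make_empty" item in the commit message: compression is metadata for grpc transfer, not a torch.empty argument, so it is stripped from the properties dict before the tensor is constructed. A small sketch; the descriptor size is arbitrary:

from hivemind import BatchTensorDescriptor

descr = BatchTensorDescriptor(1024)  # 0-th (batch) dimension is left unspecified (None)
empty = descr.make_empty(3)          # 'compression' is popped before torch.empty(**properties)
print(empty.shape)                   # torch.Size([3, 1024])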

+ 1 - 1
tests/benchmark_throughput.py

@@ -65,7 +65,7 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
         for i in range(num_experts):
             expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
             experts[f'expert{i}'] = hivemind.ExpertBackend(name=f'expert{i}',
-                                                           expert=expert, opt=torch.optim.Adam(expert.parameters()),
+                                                           expert=expert, optimizer=torch.optim.Adam(expert.parameters()),
                                                            args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
                                                            outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                                                            max_batch_size=max_batch_size,

+ 0 - 76
tests/test_checkpoints.py

@@ -1,76 +0,0 @@
-from pathlib import Path
-from tempfile import TemporaryDirectory
-
-import pytest
-import torch
-from torch.nn import Linear
-
-from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.server.checkpoints import store_experts, load_weights
-
-EXPERT_WEIGHT_UPDATES = 3
-BACKWARD_PASSES_BEFORE_SAVE = 2
-BACKWARD_PASSES_AFTER_SAVE = 2
-
-
-@pytest.mark.forked
-def test_save_load_checkpoints():
-    experts = {}
-    expert = Linear(1, 1)
-    opt = torch.optim.SGD(expert.parameters(), 0.0)
-    expert_name = f'test_expert'
-    args_schema = (BatchTensorDescriptor(1),)
-    experts[expert_name] = ExpertBackend(name=expert_name, expert=expert, opt=opt,
-                                         args_schema=args_schema,
-                                         outputs_schema=BatchTensorDescriptor(1),
-                                         max_batch_size=1,
-                                         )
-    with TemporaryDirectory() as tmpdir:
-        tmp_path = Path(tmpdir)
-
-        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
-            expert.weight.data[0] = i
-            store_experts(experts, tmp_path)
-
-        checkpoints_dir = tmp_path / expert_name
-
-        assert checkpoints_dir.exists()
-        # include checkpoint_last.pt
-        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
-
-        expert.weight.data[0] = 0
-
-        load_weights(experts, tmp_path)
-        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
-
-
-@pytest.mark.forked
-def test_restore_update_count():
-    experts = {}
-    expert = Linear(1, 1)
-    opt = torch.optim.SGD(expert.parameters(), 0.0)
-    expert_name = f'test_expert'
-    args_schema = (BatchTensorDescriptor(1),)
-    expert_backend = ExpertBackend(name=expert_name, expert=expert, opt=opt,
-                                   args_schema=args_schema,
-                                   outputs_schema=BatchTensorDescriptor(1),
-                                   max_batch_size=1,
-                                   )
-    experts[expert_name] = expert_backend
-
-    batch = torch.randn(1, 1)
-    loss_grad = torch.randn(1, 1)
-
-    with TemporaryDirectory() as tmpdir:
-        tmp_path = Path(tmpdir)
-
-        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
-            expert_backend.backward(batch, loss_grad)
-
-        store_experts(experts, tmp_path)
-
-        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
-            expert_backend.backward(batch, loss_grad)
-
-        load_weights(experts, tmp_path)
-        assert experts[expert_name].update_count == BACKWARD_PASSES_BEFORE_SAVE

+ 3 - 3
tests/test_dht_experts.py

@@ -21,7 +21,7 @@ def test_store_get_experts():
     first_peer = random.choice(peers)
     other_peer = random.choice(peers)
 
-    expert_uids = [f"my_expert.{i}" for i in range(110)]
+    expert_uids = [f"my_expert.{i}" for i in range(50)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
         hivemind.declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
@@ -41,7 +41,7 @@ def test_store_get_experts():
 
 
 @pytest.mark.forked
-def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=16,
+def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=4,
                      grid_dims=(32, 32, 32)):
     dht = []
     for i in range(dht_size):
@@ -61,7 +61,7 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
     you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
     beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)
 
-    for i in range(50):
+    for i in range(10):
         topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
         assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
         assert len(topk_experts) == beam_size

+ 1 - 2
tests/test_dht_node.py

@@ -428,7 +428,6 @@ async def test_dhtnode_blacklist():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
-
     node1 = await hivemind.DHTNode.create(blacklist_time=999)
     with pytest.raises(ValidationError):
         node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
@@ -441,7 +440,7 @@ async def test_dhtnode_edge_cases():
     peers = []
     for i in range(5):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))
 
     subkeys = [0, '', False, True, 'abyrvalg', 4555]
     keys = subkeys + [()]

+ 106 - 0
tests/test_expert_backend.py

@@ -0,0 +1,106 @@
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+import torch
+from torch.nn import Linear
+
+from hivemind import BatchTensorDescriptor, ExpertBackend
+from hivemind.server.checkpoints import store_experts, load_experts
+from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
+
+EXPERT_WEIGHT_UPDATES = 3
+BACKWARD_PASSES_BEFORE_SAVE = 2
+BACKWARD_PASSES_AFTER_SAVE = 2
+EXPERT_NAME = 'test_expert'
+PEAK_LR = 1.0
+
+
+@pytest.fixture
+def example_experts():
+    expert = Linear(1, 1)
+    opt = torch.optim.SGD(expert.parameters(), PEAK_LR)
+
+    args_schema = (BatchTensorDescriptor(1),)
+    expert_backend = ExpertBackend(name=EXPERT_NAME, expert=expert, optimizer=opt,
+                                   scheduler=get_linear_schedule_with_warmup,
+                                   num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
+                                   num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+                                   args_schema=args_schema, outputs_schema=BatchTensorDescriptor(1), max_batch_size=1,
+                                   )
+    experts = {EXPERT_NAME: expert_backend}
+    yield experts
+
+
+@pytest.mark.forked
+def test_save_load_checkpoints(example_experts):
+    expert = example_experts[EXPERT_NAME].expert
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+
+        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
+            expert.weight.data[0] = i
+            store_experts(example_experts, tmp_path)
+
+        checkpoints_dir = tmp_path / EXPERT_NAME
+
+        assert checkpoints_dir.exists()
+        # include checkpoint_last.pt
+        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
+
+        expert.weight.data[0] = 0
+
+        load_experts(example_experts, tmp_path)
+        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
+
+
+@pytest.mark.forked
+def test_restore_update_count(example_experts):
+    expert_backend = example_experts[EXPERT_NAME]
+
+    batch = torch.randn(1, 1)
+    loss_grad = torch.randn(1, 1)
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+
+        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
+            expert_backend.backward(batch, loss_grad)
+
+        store_experts(example_experts, tmp_path)
+
+        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
+            expert_backend.backward(batch, loss_grad)
+
+        load_experts(example_experts, tmp_path)
+        assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE
+
+
+@pytest.mark.forked
+def test_lr_schedule(example_experts):
+    expert_backend = example_experts[EXPERT_NAME]
+    optimizer = expert_backend.optimizer
+
+    batch = torch.randn(1, 1)
+    loss_grad = torch.randn(1, 1)
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+
+        assert optimizer.param_groups[0]['lr'] == 0.0
+
+        for i in range(BACKWARD_PASSES_BEFORE_SAVE):
+            assert optimizer.param_groups[0]['lr'] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
+            expert_backend.backward(batch, loss_grad)
+
+        assert optimizer.param_groups[0]['lr'] == PEAK_LR
+        store_experts(example_experts, tmp_path)
+
+        for i in range(BACKWARD_PASSES_AFTER_SAVE):
+            assert optimizer.param_groups[0]['lr'] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
+            expert_backend.backward(batch, loss_grad)
+
+        assert optimizer.param_groups[0]['lr'] == 0.0
+        load_experts(example_experts, tmp_path)
+        assert optimizer.param_groups[0]['lr'] == PEAK_LR

+ 6 - 6
tests/test_moe.py

@@ -12,15 +12,15 @@ from hivemind.server import layers
 @pytest.mark.forked
 def test_moe():
     all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
-                       for _ in range(20)]
-    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
-                           num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
+                       for _ in range(10)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn', num_handlers=1,
+                           hidden_dim=16) as (server_endpoint, dht_endpoint):
         dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
 
         dmoe = hivemind.RemoteMixtureOfExperts(
             in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn.')
 
-        for i in range(10):
+        for i in range(5):
             out = dmoe(torch.randn(10, 16))
             out.sum().backward()
 
@@ -35,7 +35,7 @@ def test_call_many(hidden_dim=16):
     detect_anomalies = False
     atol = 1e-5
 
-    with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=hidden_dim,
+    with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
                            optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
@@ -182,7 +182,7 @@ def test_client_anomaly_detection():
     for i in range(4):
         expert = layers.name_to_block['ffn'](HID_DIM)
         experts[f'expert.{i}'] = hivemind.ExpertBackend(name=f'expert.{i}',
-                                                        expert=expert, opt=torch.optim.Adam(expert.parameters()),
+                                                        expert=expert, optimizer=torch.optim.Adam(expert.parameters()),
                                                         args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
                                                         outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
                                                         max_batch_size=16,

+ 3 - 3
tests/test_training.py

@@ -1,5 +1,4 @@
 from functools import partial
-from typing import Optional
 
 import pytest
 import torch
@@ -11,12 +10,13 @@ from hivemind import RemoteExpert, background_server
 
 
 @pytest.mark.forked
-def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: float = 0.9):
+def test_training(max_steps: int = 100, threshold: float = 0.9):
     dataset = load_digits()
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64) as (server_endpoint, _):
+    with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1,
+                           no_dht=True) as (server_endpoint, dht_endpoint):
         expert1 = RemoteExpert('expert.0', server_endpoint)
         expert2 = RemoteExpert('expert.1', server_endpoint)
         model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))