
Support learning rate schedulers in ExpertBackend (#196)

* Add empty __init__ to hivemind_cli for correct package discovery

* Support learning rate schedulers in ExpertBackend

* Save/load full expert state

* Don't pass compression to make_empty

* spawn -> fork

* Remove load_expert_states

* Make TaskPoolBase an abstract class

* Output warning if some of the keys in state_dict are missing

Co-authored-by: justheuristic <justheuristic@gmail.com>
Max Ryabinin 4 years ago
parent
commit
3024d381c5

+ 6 - 6
.circleci/config.yml

@@ -8,11 +8,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:
@@ -28,11 +28,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:
@@ -48,11 +48,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:

+ 0 - 0
hivemind/hivemind_cli/__init__.py


+ 8 - 1
hivemind/hivemind_cli/run_server.py

@@ -8,6 +8,7 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.server import Server
 from hivemind.utils.threading import increase_file_limit
 from hivemind.utils.logging import get_logger
+from hivemind.server.layers import schedule_name_to_scheduler
 
 logger = get_logger(__name__)
 
@@ -28,13 +29,20 @@ def main():
     parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
                         help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
+
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
                         help='The total number of examples in the same batch will not exceed this value')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
+
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
+    parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',
+                        help='LR scheduler type to use')
+    parser.add_argument('--num-warmup-steps', type=int, required=False, help='the number of warmup steps for LR schedule')
+    parser.add_argument('--num-training-steps', type=int, required=False, help='the total number of steps for LR schedule')
+
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='one or more peers that can welcome you to the dht, e.g. 1.2.3.4:1337 192.132.231.4:4321')
@@ -45,7 +53,6 @@ def main():
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
                         'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
     parser.add_argument('--checkpoint_dir', type=Path, required=False, help='Directory to store expert checkpoints')
-    parser.add_argument('--load_experts', action='store_true', help='Load experts from the checkpoint directory')
 
     # fmt:on
     args = vars(parser.parse_args())
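
Note (not part of the diff): a minimal sketch of how the new scheduler options are expected to reach the server from Python; the expert pattern, step counts and other values below are illustrative placeholders, not something this commit prescribes.

# Sketch: Server.create() with the new scheduler arguments (see hivemind/server/__init__.py below).
from hivemind.server import Server

server = Server.create(expert_cls='ffn', hidden_dim=1024,
                       num_experts=2, expert_pattern='expert.[0:256]',
                       scheduler='linear',          # resolved via schedule_name_to_scheduler
                       num_warmup_steps=1000,       # LR grows linearly for the first 1000 steps
                       num_training_steps=100000,   # then decays linearly to zero by step 100000
                       start=True)
server.shutdown()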

+ 48 - 94
hivemind/server/__init__.py

@@ -2,23 +2,22 @@ from __future__ import annotations
 
 import multiprocessing as mp
 import multiprocessing.synchronize
-import random
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, Optional, Tuple, List
+from typing import Dict, Optional, Tuple
 from pathlib import Path
 
 import torch
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.expert_uid import UID_DELIMITER
-from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
+from hivemind.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.server.checkpoints import CheckpointSaver, load_experts, is_directory
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.server.layers import name_to_block, name_to_input
+from hivemind.server.layers import name_to_block, name_to_input, schedule_name_to_scheduler
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
@@ -68,11 +67,12 @@ class Server(threading.Thread):
         if start:
             self.run_in_background(await_ready=True)
 
-    @staticmethod
-    def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
-               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+    @classmethod
+    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
+               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
+               num_warmup_steps=None, num_training_steps=None, num_handlers=None, max_batch_size=4096, device=None,
+               no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,7 +85,12 @@ class Server(threading.Thread):
         :param num_handlers: server will use this many parallel processes to handle incoming requests
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+
         :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_training_steps: the total number of steps for LR schedule
+
         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
            e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers
@@ -93,8 +98,7 @@ class Server(threading.Thread):
         :param dht_port:  DHT node will listen on this port, default = find open port
           You can then use this node as initial peer for subsequent servers.
 
-        :param checkpoint_dir: directory to save expert checkpoints
-        :param load_experts: whether to load expert checkpoints from checkpoint_dir
+        :param checkpoint_dir: directory to save and load expert checkpoints
 
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
            hosted on this server. For a more fine-grained compression, start server in python and specify compression
@@ -113,23 +117,29 @@ class Server(threading.Thread):
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
             logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
-        if load_experts:
-            assert dir_is_correct(checkpoint_dir)
-            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
-            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
-            if expert_uids:
-                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
-            else:
-                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")
-
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
+        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
+                (num_experts is not None and expert_uids is None)), \
+            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
 
-        # get expert uids if not loaded previously
         if expert_uids is None:
-            assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
-            logger.info(f"Generating expert uids from pattern {expert_pattern}")
-            expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
+                               (child / 'checkpoint_last.pt').exists()]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
@@ -142,6 +152,8 @@ class Server(threading.Thread):
         else:
             args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
+        scheduler = schedule_name_to_scheduler[scheduler]
+
         # initialize experts
         experts = {}
         for expert_uid in expert_uids:
@@ -150,15 +162,17 @@ class Server(threading.Thread):
                                                          args_schema=args_schema,
                                                          outputs_schema=hivemind.BatchTensorDescriptor(
                                                              hidden_dim, compression=compression),
-                                                         opt=optim_cls(expert.parameters()),
+                                                         optimizer=optim_cls(expert.parameters()),
+                                                         scheduler=scheduler,
+                                                         num_warmup_steps=num_warmup_steps,
+                                                         num_training_steps=num_training_steps,
                                                          max_batch_size=max_batch_size)
 
-        if load_experts:
-            load_weights(experts, checkpoint_dir)
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)
 
-        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
-                        start=start)
-        return server
+        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
+                   checkpoint_dir=checkpoint_dir, start=start)
 
     def run(self):
         """
@@ -241,7 +255,7 @@ class Server(threading.Thread):
 def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.get_context("spawn").Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
 
     try:
         runner.start()
@@ -269,63 +283,3 @@ def _server_runner(pipe, *args, **kwargs):
         server.shutdown()
         server.join()
         logger.info("Server shut down.")
-
-
-def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
-                               attempts_per_expert=10) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if '[' not in block and ']' not in block:
-                    uid.append(block)
-                elif block.startswith('[') and block.endswith(']') and ':' in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(':'))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt as e:
-                raise e
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
-                                    if found_expert is not None}
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-                       f"{attempts_per_expert * num_experts} attempts")
-    return found_uids

+ 15 - 8
hivemind/server/checkpoints.py

@@ -1,17 +1,20 @@
+import os
 import threading
 from datetime import datetime
 from pathlib import Path
 from shutil import copy2
 from tempfile import TemporaryDirectory
 from typing import Dict
-import os
 
 import torch
 
 from hivemind.server.expert_backend import ExpertBackend
+from hivemind.utils.logging import get_logger
 
+logger = get_logger(__name__)
 
-def dir_is_correct(directory: Path):
+
+def is_directory(directory: Path):
     assert directory is not None
     assert directory.exists()
     assert directory.is_dir()
@@ -33,7 +36,7 @@ def copy_tree(src: str, dst: str):
 class CheckpointSaver(threading.Thread):
     def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
         super().__init__()
-        assert dir_is_correct(checkpoint_dir)
+        assert is_directory(checkpoint_dir)
         self.expert_backends = expert_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
@@ -48,21 +51,25 @@ class CheckpointSaver(threading.Thread):
 
 
 def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
-    assert dir_is_correct(checkpoint_dir)
+    logger.debug(f'Storing experts at {checkpoint_dir.absolute()}')
+    assert is_directory(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep='_')
     with TemporaryDirectory() as tmpdirname:
         for expert_name, expert_backend in experts.items():
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
             checkpoint_name = expert_dir / f'checkpoint_{timestamp}.pt'
-            torch.save(expert_backend.state_dict(), checkpoint_name)
+            torch.save(expert_backend.get_full_state(), checkpoint_name)
             os.symlink(checkpoint_name, expert_dir / 'checkpoint_last.pt')
         copy_tree(tmpdirname, str(checkpoint_dir))
 
 
-def load_weights(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
-    assert dir_is_correct(checkpoint_dir)
+def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+    assert is_directory(checkpoint_dir)
     for expert_name, expert in experts.items():
         checkpoints_folder = checkpoint_dir / expert_name
         latest_checkpoint = checkpoints_folder / 'checkpoint_last.pt'
-        expert.load_state_dict(torch.load(latest_checkpoint))
+        if latest_checkpoint.exists():
+            expert.load_full_state(torch.load(latest_checkpoint))
+        else:
+            logger.warning(f'Failed to load checkpoint for expert {expert_name}')
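
Note (not part of the diff): a short, self-contained sketch of the renamed checkpoint helpers; the toy Linear expert and temporary directory below are placeholders and mirror the updated tests further down.

# Sketch: store_experts() / load_experts() round trip for a single toy expert.
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from torch.nn import Linear

from hivemind import BatchTensorDescriptor, ExpertBackend
from hivemind.server.checkpoints import store_experts, load_experts

expert = Linear(1, 1)
backend = ExpertBackend(name='demo_expert', expert=expert,
                        optimizer=torch.optim.SGD(expert.parameters(), lr=0.1),
                        args_schema=(BatchTensorDescriptor(1),),
                        outputs_schema=BatchTensorDescriptor(1), max_batch_size=1)
experts = {'demo_expert': backend}

with TemporaryDirectory() as tmpdir:
    store_experts(experts, Path(tmpdir))  # writes demo_expert/checkpoint_<timestamp>.pt plus a checkpoint_last.pt symlink
    load_experts(experts, Path(tmpdir))   # restores the latest checkpoint for every expert that has one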

+ 1 - 1
hivemind/server/connection_handler.py

@@ -16,7 +16,7 @@ from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 logger = get_logger(__name__)
 
 
-class ConnectionHandler(mp.Process):
+class ConnectionHandler(mp.context.ForkProcess):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
 

+ 74 - 12
hivemind/server/expert_backend.py

@@ -1,14 +1,17 @@
-from typing import Dict, Sequence, Any, Tuple, Union
+from typing import Dict, Sequence, Any, Tuple, Union, Callable
 
 import torch
 from torch import nn
 
 from hivemind.server.task_pool import TaskPool
-from hivemind.utils import nested_flatten, nested_pack, nested_compare, nested_map, \
-    BatchTensorDescriptor, DUMMY_BATCH_SIZE
+from hivemind.utils import BatchTensorDescriptor, DUMMY_BATCH_SIZE
+from hivemind.utils.logging import get_logger
+from hivemind.utils.nested import nested_flatten, nested_pack, nested_compare, nested_map
 
+logger = get_logger(__name__)
 
-class ExpertBackend(nn.Module):
+
+class ExpertBackend:
     """
     ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
     By default, ExpertBackend handles three types of requests:
@@ -26,20 +29,31 @@ class ExpertBackend(nn.Module):
         you should explicitly register these random variables as model inputs or outputs.
         See hivemind.utils.custom_layers.DeterministicDropout for an example
 
-    :param opt: torch optimizer to be applied on every backward call
+    :param optimizer: torch optimizer to be applied on every backward call
+    :param scheduler: a function to create the learning rate scheduler for the expert
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
     :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
     :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
+    :param num_warmup_steps: the number of warmup steps for LR schedule
+    :param num_training_steps: the total number of steps for LR schedule
     :param kwargs: extra parameters to be forwarded into TaskPool.__init__
     """
 
-    def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
+    def __init__(self, name: str, expert: nn.Module, optimizer: torch.optim.Optimizer, *,
+                 scheduler: Callable = None,
                  args_schema: Tuple[BatchTensorDescriptor, ...] = None,
                  kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
                  outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
+                 num_warmup_steps: int = None, num_training_steps: int = None,
                  **kwargs):
         super().__init__()
-        self.expert, self.opt, self.name = expert, opt, name
+        self.expert, self.optimizer, self.name = expert, optimizer, name
+
+        if scheduler is None:
+            self.scheduler = None
+        else:
+            assert optimizer is not None and num_warmup_steps is not None and num_training_steps is not None
+            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_training_steps)
 
         self.args_schema = args_schema = tuple(args_schema or ())
         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
@@ -61,7 +75,8 @@ class ExpertBackend(nn.Module):
         self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
         self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
 
-        self.register_buffer('update_count', torch.zeros(1, dtype=torch.long))
+        self.update_count = 0
+        self.examples_processed = 0
 
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
@@ -111,6 +126,8 @@ class ExpertBackend(nn.Module):
                                   if tensor.is_floating_point() else tensor.detach())
                       for input_key, tensor in kwargs.items()}
 
+            batch_size = args[0].size(0)
+
             outputs = self.expert(*args, **kwargs)
             assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
 
@@ -121,18 +138,63 @@ class ExpertBackend(nn.Module):
                 nested_flatten(grad_outputs), outputs_flat))
             torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat,
                                     create_graph=False, retain_graph=False)
-            self.apply_gradients()
+            self.apply_gradients(batch_size)
 
         return tuple(x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
                      for x in nested_flatten((args, kwargs)))
 
-    def apply_gradients(self) -> None:
+    def apply_gradients(self, batch_size) -> None:
         """
         Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
         """
-        self.opt.step()
-        self.opt.zero_grad()
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+
+        if self.scheduler is not None:
+            self.scheduler.step()
+
         self.update_count += 1
+        self.examples_processed += batch_size
+
+    def get_stats(self) -> Dict:
+        """
+        Return current expert training statistics (number of updates, number of processed examples after last optimizer step)
+        """
+        return {
+            'updates': self.update_count,
+            'examples_processed': self.examples_processed
+        }
+
+    def get_full_state(self) -> Dict:
+        """
+        Return the current state of the expert (including batch processing statistics)
+        """
+        full_state = {
+            'stats': self.get_stats(),
+            'model': self.expert.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'scheduler': {} if self.scheduler is None else self.scheduler.state_dict()
+        }
+        return full_state
+
+    def load_full_state(self, state_dict: Dict):
+        if 'stats' in state_dict:
+            self.update_count = state_dict['stats']['updates']
+            self.examples_processed = state_dict['stats']['examples_processed']
+        else:
+            logger.warning(f'Batch processing stats missing for expert {self.name}')
+
+        self.expert.load_state_dict(state_dict['model'])
+
+        if 'optimizer' in state_dict:
+            self.optimizer.load_state_dict(state_dict['optimizer'])
+        else:
+            logger.warning(f'Optimizer state missing for expert {self.name}')
+
+        if self.scheduler is not None and 'scheduler' in state_dict:
+            self.scheduler.load_state_dict(state_dict['scheduler'])
+        else:
+            logger.warning(f'Learning rate scheduler state missing for expert {self.name}')
 
     def get_info(self) -> Dict[str, Any]:
         """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """

+ 63 - 2
hivemind/server/expert_uid.py

@@ -1,8 +1,11 @@
+import random
 import re
-from typing import NamedTuple, Union, Tuple
+from typing import NamedTuple, Union, Tuple, Optional, List
 
-from hivemind.utils.networking import Endpoint
+from hivemind.dht import DHT
+from hivemind.utils import Endpoint, get_logger
 
+logger = get_logger(__name__)
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
 UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
@@ -30,3 +33,61 @@ def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPref
     return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
 
 
+def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
+                               attempts_per_expert=10) -> List[str]:
+    """
+    Sample experts from a given pattern, remove duplicates.
+    :param num_experts: sample this many unique expert uids
+    :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
+     means "sample random experts between myprefix.0.0 and myprefix.255.255;
+    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
+    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
+    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
+     a small chance of sampling duplicate expert uids.
+    """
+    remaining_attempts = attempts_per_expert * num_experts
+    found_uids, attempted_uids = list(), set()
+
+    def _generate_uid():
+        if expert_pattern is None:
+            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
+
+        uid = []
+        for block in expert_pattern.split(UID_DELIMITER):
+            try:
+                if '[' not in block and ']' not in block:
+                    uid.append(block)
+                elif block.startswith('[') and block.endswith(']') and ':' in block:
+                    slice_start, slice_end = map(int, block[1:-1].split(':'))
+                    uid.append(str(random.randint(slice_start, slice_end - 1)))
+                else:
+                    raise ValueError("Block must be either fixed or a range [from:to]")
+            except KeyboardInterrupt as e:
+                raise e
+            except Exception as e:
+                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
+        return UID_DELIMITER.join(uid)
+
+    while remaining_attempts > 0 and len(found_uids) < num_experts:
+
+        # 1. sample new expert uids at random
+        new_uids = []
+        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
+            new_uid = _generate_uid()
+            remaining_attempts -= 1
+            if new_uid not in attempted_uids:
+                attempted_uids.add(new_uid)
+                new_uids.append(new_uid)
+
+        # 2. look into DHT (if given) and remove duplicates
+        if dht:
+            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
+                                    if found_expert is not None}
+            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
+
+        found_uids += new_uids
+
+    if len(found_uids) != num_experts:
+        logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
+                       f"{attempts_per_expert * num_experts} attempts")
+    return found_uids
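
Note (not part of the diff): a quick sketch of the relocated UID generator; the pattern and the printed UIDs below are illustrative only, since actual values are sampled at random.

# Sketch: each '[from:to]' block in the pattern becomes a random integer in [from, to).
from hivemind.server.expert_uid import generate_uids_from_pattern

uids = generate_uids_from_pattern(num_experts=3, expert_pattern='ffn_expert.[0:4].[0:16]')
print(uids)  # e.g. ['ffn_expert.2.11', 'ffn_expert.0.7', 'ffn_expert.3.14']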

+ 4 - 1
hivemind/server/layers/__init__.py

@@ -2,6 +2,7 @@ import torch
 
 
 from hivemind.server.layers.common import FeedforwardBlock, TransformerEncoderLayer, NopExpert
 from hivemind.server.layers.dropout import DeterministicDropout, DeterministicDropoutNetwork
+from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
 name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
                  'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, dim_feedforward=4 * hid_dim, nhead=16),
@@ -10,7 +11,9 @@ name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
 
 name_to_input = {'ffn': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
                  'transformer': lambda batch_size, hid_dim:
-                 (torch.empty((batch_size, 128, hid_dim)), torch.empty((batch_size, hid_dim), dtype=torch.bool)),
+                 (torch.empty((batch_size, 128, hid_dim)), torch.empty((batch_size, 128), dtype=torch.bool)),
                  'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
                  'det_dropout': lambda batch_size, hid_dim:
                  (torch.empty((batch_size, hid_dim)), torch.randint(0, 1, (batch_size, hid_dim)))}
+
+schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}

+ 27 - 0
hivemind/server/layers/lr_schedule.py

@@ -0,0 +1,27 @@
+from torch.optim.lr_scheduler import LambdaLR
+
+
+# https://github.com/huggingface/transformers/blob/master/src/transformers/optimization.py
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
+    """
+    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+    Args:
+        optimizer (:class:`~torch.optim.Optimizer`):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (:obj:`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (:obj:`int`):
+            The total number of training steps.
+    Return:
+        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        return max(
+            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+        )
+
+    return LambdaLR(optimizer, lr_lambda)
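
Note (not part of the diff): a quick sanity-check sketch of the schedule shape; the dummy parameter, peak LR and step counts are arbitrary.

# Sketch: LR ramps from 0 to the peak over num_warmup_steps, then decays linearly towards 0.
import torch
from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1.0)
scheduler = get_linear_schedule_with_warmup(opt, num_warmup_steps=2, num_training_steps=10)

lrs = []
for _ in range(10):
    opt.step()
    scheduler.step()
    lrs.append(opt.param_groups[0]['lr'])
# lrs starts at 0.5, reaches the peak of 1.0 after warmup, then decreases linearly towards 0.0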

+ 1 - 1
hivemind/server/runtime.py

@@ -59,7 +59,7 @@ class Runtime(threading.Thread):
                 pool.start()
         if self.device is not None:
             for expert_backend in self.expert_backends.values():
-                expert_backend.to(self.device)
+                expert_backend.expert.to(self.device)
 
         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:

+ 10 - 6
hivemind/server/task_pool.py

@@ -3,11 +3,11 @@ Task pool is responsible for receiving tasks and grouping them together for proc
 """
 """
 import ctypes
 import ctypes
 import multiprocessing as mp
 import multiprocessing as mp
-import multiprocessing.context
 import os
 import os
 import threading
 import threading
 import time
 import time
 import uuid
 import uuid
+from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from collections import namedtuple
 from concurrent.futures import Future
 from concurrent.futures import Future
 from queue import Empty
 from queue import Empty
@@ -21,7 +21,7 @@ logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
 Task = namedtuple("Task", ("future", "args"))
 
 
 
 
-class TaskPoolBase(mp.context.ForkProcess):
+class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
     """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """
     """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """
 
 
     def __init__(self, process_func: callable, daemon=True):
     def __init__(self, process_func: callable, daemon=True):
@@ -29,14 +29,17 @@ class TaskPoolBase(mp.context.ForkProcess):
         self.process_func = process_func
         self.process_func = process_func
         self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
         self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
 
 
+    @abstractmethod
     def run(self):
     def run(self):
-        raise NotImplementedError()
+        pass
 
 
+    @abstractmethod
     def submit_task(self, *args: torch.Tensor) -> Future:
     def submit_task(self, *args: torch.Tensor) -> Future:
-        raise NotImplementedError()
+        pass
 
 
+    @abstractmethod
     def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
     def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
-        raise NotImplementedError()
+        pass
 
 
     @property
     @property
     def priority(self):
     def priority(self):
@@ -47,8 +50,9 @@ class TaskPoolBase(mp.context.ForkProcess):
         self._priority.value = float(value)
         self._priority.value = float(value)
 
 
     @property
     @property
+    @abstractmethod
     def empty(self):
     def empty(self):
-        raise NotImplementedError()
+        pass
 
 
 
 
 class TaskPool(TaskPoolBase):
 class TaskPool(TaskPoolBase):

+ 1 - 2
hivemind/utils/serializer.py

@@ -1,9 +1,8 @@
 """ A unified interface for several common serialization methods """
 """ A unified interface for several common serialization methods """
-from io import BytesIO
 from typing import Dict, Any
 from typing import Dict, Any
 
 
-import torch
 import msgpack
 import msgpack
+
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)

+ 6 - 2
hivemind/utils/tensor_descr.py

@@ -8,6 +8,8 @@ from hivemind.proto.runtime_pb2 import CompressionType
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
 warnings.filterwarnings("ignore", "CUDA initialization*", category=UserWarning)
+
+
 # ^-- cures https://github.com/pytorch/pytorch/issues/47038
 
 
@@ -32,11 +34,13 @@ class TensorDescriptor(DescriptorBase):
 
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor):
-        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, safe_check_pinned(tensor))
+        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad,
+                   safe_check_pinned(tensor))
 
     def make_empty(self, **kwargs):
         properties = asdict(self)
         properties.update(kwargs)
+        properties.pop('compression')
         return torch.empty(**properties)
 
 
@@ -60,7 +64,7 @@ class BatchTensorDescriptor(TensorDescriptor):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs)
 
-    
+
 def safe_check_pinned(tensor: torch.Tensor) -> bool:
     """ Check whether or not a tensor is pinned. If torch cannot initialize cuda, returns False instead of error. """
     try:

+ 1 - 1
tests/benchmark_throughput.py

@@ -65,7 +65,7 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
         for i in range(num_experts):
             expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
             experts[f'expert{i}'] = hivemind.ExpertBackend(name=f'expert{i}',
-                                                           expert=expert, opt=torch.optim.Adam(expert.parameters()),
+                                                           expert=expert, optimizer=torch.optim.Adam(expert.parameters()),
                                                            args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
                                                            outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                                                            max_batch_size=max_batch_size,

+ 0 - 76
tests/test_checkpoints.py

@@ -1,76 +0,0 @@
-from pathlib import Path
-from tempfile import TemporaryDirectory
-
-import pytest
-import torch
-from torch.nn import Linear
-
-from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.server.checkpoints import store_experts, load_weights
-
-EXPERT_WEIGHT_UPDATES = 3
-BACKWARD_PASSES_BEFORE_SAVE = 2
-BACKWARD_PASSES_AFTER_SAVE = 2
-
-
-@pytest.mark.forked
-def test_save_load_checkpoints():
-    experts = {}
-    expert = Linear(1, 1)
-    opt = torch.optim.SGD(expert.parameters(), 0.0)
-    expert_name = f'test_expert'
-    args_schema = (BatchTensorDescriptor(1),)
-    experts[expert_name] = ExpertBackend(name=expert_name, expert=expert, opt=opt,
-                                         args_schema=args_schema,
-                                         outputs_schema=BatchTensorDescriptor(1),
-                                         max_batch_size=1,
-                                         )
-    with TemporaryDirectory() as tmpdir:
-        tmp_path = Path(tmpdir)
-
-        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
-            expert.weight.data[0] = i
-            store_experts(experts, tmp_path)
-
-        checkpoints_dir = tmp_path / expert_name
-
-        assert checkpoints_dir.exists()
-        # include checkpoint_last.pt
-        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
-
-        expert.weight.data[0] = 0
-
-        load_weights(experts, tmp_path)
-        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
-
-
-@pytest.mark.forked
-def test_restore_update_count():
-    experts = {}
-    expert = Linear(1, 1)
-    opt = torch.optim.SGD(expert.parameters(), 0.0)
-    expert_name = f'test_expert'
-    args_schema = (BatchTensorDescriptor(1),)
-    expert_backend = ExpertBackend(name=expert_name, expert=expert, opt=opt,
-                                   args_schema=args_schema,
-                                   outputs_schema=BatchTensorDescriptor(1),
-                                   max_batch_size=1,
-                                   )
-    experts[expert_name] = expert_backend
-
-    batch = torch.randn(1, 1)
-    loss_grad = torch.randn(1, 1)
-
-    with TemporaryDirectory() as tmpdir:
-        tmp_path = Path(tmpdir)
-
-        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
-            expert_backend.backward(batch, loss_grad)
-
-        store_experts(experts, tmp_path)
-
-        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
-            expert_backend.backward(batch, loss_grad)
-
-        load_weights(experts, tmp_path)
-        assert experts[expert_name].update_count == BACKWARD_PASSES_BEFORE_SAVE

+ 3 - 3
tests/test_dht_experts.py

@@ -21,7 +21,7 @@ def test_store_get_experts():
     first_peer = random.choice(peers)
     first_peer = random.choice(peers)
     other_peer = random.choice(peers)
     other_peer = random.choice(peers)
 
 
-    expert_uids = [f"my_expert.{i}" for i in range(110)]
+    expert_uids = [f"my_expert.{i}" for i in range(50)]
     batch_size = 10
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
     for batch_start in range(0, len(expert_uids), batch_size):
         hivemind.declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
         hivemind.declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
@@ -41,7 +41,7 @@ def test_store_get_experts():
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
-def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=16,
+def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=4,
                      grid_dims=(32, 32, 32)):
                      grid_dims=(32, 32, 32)):
     dht = []
     dht = []
     for i in range(dht_size):
     for i in range(dht_size):
@@ -61,7 +61,7 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
     you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
     you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
     beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)
     beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)
 
 
-    for i in range(50):
+    for i in range(10):
         topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
         assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
         assert len(topk_experts) == beam_size

+ 1 - 2
tests/test_dht_node.py

@@ -428,7 +428,6 @@ async def test_dhtnode_blacklist():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
-
     node1 = await hivemind.DHTNode.create(blacklist_time=999)
     with pytest.raises(ValidationError):
         node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
@@ -441,7 +440,7 @@ async def test_dhtnode_edge_cases():
     peers = []
     for i in range(5):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))
 
     subkeys = [0, '', False, True, 'abyrvalg', 4555]
     keys = subkeys + [()]

+ 106 - 0
tests/test_expert_backend.py

@@ -0,0 +1,106 @@
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+import torch
+from torch.nn import Linear
+
+from hivemind import BatchTensorDescriptor, ExpertBackend
+from hivemind.server.checkpoints import store_experts, load_experts
+from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
+
+EXPERT_WEIGHT_UPDATES = 3
+BACKWARD_PASSES_BEFORE_SAVE = 2
+BACKWARD_PASSES_AFTER_SAVE = 2
+EXPERT_NAME = 'test_expert'
+PEAK_LR = 1.0
+
+
+@pytest.fixture
+def example_experts():
+    expert = Linear(1, 1)
+    opt = torch.optim.SGD(expert.parameters(), PEAK_LR)
+
+    args_schema = (BatchTensorDescriptor(1),)
+    expert_backend = ExpertBackend(name=EXPERT_NAME, expert=expert, optimizer=opt,
+                                   scheduler=get_linear_schedule_with_warmup,
+                                   num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
+                                   num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+                                   args_schema=args_schema, outputs_schema=BatchTensorDescriptor(1), max_batch_size=1,
+                                   )
+    experts = {EXPERT_NAME: expert_backend}
+    yield experts
+
+
+@pytest.mark.forked
+def test_save_load_checkpoints(example_experts):
+    expert = example_experts[EXPERT_NAME].expert
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+
+        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
+            expert.weight.data[0] = i
+            store_experts(example_experts, tmp_path)
+
+        checkpoints_dir = tmp_path / EXPERT_NAME
+
+        assert checkpoints_dir.exists()
+        # include checkpoint_last.pt
+        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
+
+        expert.weight.data[0] = 0
+
+        load_experts(example_experts, tmp_path)
+        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
+
+
+@pytest.mark.forked
+def test_restore_update_count(example_experts):
+    expert_backend = example_experts[EXPERT_NAME]
+
+    batch = torch.randn(1, 1)
+    loss_grad = torch.randn(1, 1)
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+
+        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
+            expert_backend.backward(batch, loss_grad)
+
+        store_experts(example_experts, tmp_path)
+
+        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
+            expert_backend.backward(batch, loss_grad)
+
+        load_experts(example_experts, tmp_path)
+        assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE
+
+
+@pytest.mark.forked
+def test_lr_schedule(example_experts):
+    expert_backend = example_experts[EXPERT_NAME]
+    optimizer = expert_backend.optimizer
+
+    batch = torch.randn(1, 1)
+    loss_grad = torch.randn(1, 1)
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+
+        assert optimizer.param_groups[0]['lr'] == 0.0
+
+        for i in range(BACKWARD_PASSES_BEFORE_SAVE):
+            assert optimizer.param_groups[0]['lr'] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
+            expert_backend.backward(batch, loss_grad)
+
+        assert optimizer.param_groups[0]['lr'] == PEAK_LR
+        store_experts(example_experts, tmp_path)
+
+        for i in range(BACKWARD_PASSES_AFTER_SAVE):
+            assert optimizer.param_groups[0]['lr'] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
+            expert_backend.backward(batch, loss_grad)
+
+        assert optimizer.param_groups[0]['lr'] == 0.0
+        load_experts(example_experts, tmp_path)
+        assert optimizer.param_groups[0]['lr'] == PEAK_LR

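The schedule asserted in test_lr_schedule above is a linear warmup from 0 to PEAK_LR over num_warmup_steps backward passes, followed by a linear decay back to 0 at num_training_steps. A minimal sketch consistent with those assertions, assuming a torch LambdaLR-based implementation (the actual get_linear_schedule_with_warmup in hivemind/server/layers/lr_schedule.py may differ in details):

from torch.optim.lr_scheduler import LambdaLR


def linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    # LambdaLR multiplies the optimizer's base LR (PEAK_LR in the test) by this factor.
    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            # warmup: factor grows linearly from 0 to 1 over num_warmup_steps steps
            return current_step / max(1, num_warmup_steps)
        # decay: factor shrinks linearly from 1 to 0 over the remaining steps, clamped at 0
        return max(0.0, (num_training_steps - current_step) /
                   max(1, num_training_steps - num_warmup_steps))

    return LambdaLR(optimizer, lr_lambda)

With num_warmup_steps=2 and num_training_steps=4 (the constants used by the test), this yields learning rates of 0.0, 0.5, 1.0, 0.5, 0.0 across five consecutive steps, which are exactly the values test_lr_schedule checks before and after saving and loading the expert state.
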
+ 6 - 6
tests/test_moe.py

@@ -12,15 +12,15 @@ from hivemind.server import layers
 @pytest.mark.forked
 def test_moe():
     all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
-                       for _ in range(20)]
-    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
-                           num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
+                       for _ in range(10)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn', num_handlers=1,
+                           hidden_dim=16) as (server_endpoint, dht_endpoint):
         dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
 
         dmoe = hivemind.RemoteMixtureOfExperts(
             in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn.')
 
-        for i in range(10):
+        for i in range(5):
             out = dmoe(torch.randn(10, 16))
             out.sum().backward()
 
@@ -35,7 +35,7 @@ def test_call_many(hidden_dim=16):
     detect_anomalies = False
     atol = 1e-5
 
-    with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=hidden_dim,
+    with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
                            optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
@@ -182,7 +182,7 @@ def test_client_anomaly_detection():
     for i in range(4):
         expert = layers.name_to_block['ffn'](HID_DIM)
         experts[f'expert.{i}'] = hivemind.ExpertBackend(name=f'expert.{i}',
-                                                        expert=expert, opt=torch.optim.Adam(expert.parameters()),
+                                                        expert=expert, optimizer=torch.optim.Adam(expert.parameters()),
                                                         args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
                                                         outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
                                                         max_batch_size=16,

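Together with the new test_expert_backend.py fixture, the hunk above also illustrates the ExpertBackend API change: the optimizer is now passed as optimizer= (previously opt=), and a scheduler factory plus its keyword arguments can optionally be supplied alongside it. A minimal construction sketch reusing only arguments visible in this diff; the dimensions, step counts, and expert name below are illustrative:

import torch
from torch.nn import Linear

from hivemind import BatchTensorDescriptor, ExpertBackend
from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup

expert = Linear(16, 16)
backend = ExpertBackend(name='expert.0', expert=expert,
                        optimizer=torch.optim.Adam(expert.parameters()),  # was opt= before this commit
                        scheduler=get_linear_schedule_with_warmup,        # optional, as in test_client_anomaly_detection
                        num_warmup_steps=100, num_training_steps=1000,    # forwarded to the scheduler
                        args_schema=(BatchTensorDescriptor(16),),
                        outputs_schema=BatchTensorDescriptor(16),
                        max_batch_size=16)
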
+ 3 - 3
tests/test_training.py

@@ -1,5 +1,4 @@
 from functools import partial
-from typing import Optional
 
 import pytest
 import torch
@@ -11,12 +10,13 @@ from hivemind import RemoteExpert, background_server
 
 
 @pytest.mark.forked
-def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: float = 0.9):
+def test_training(max_steps: int = 100, threshold: float = 0.9):
     dataset = load_digits()
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64) as (server_endpoint, _):
+    with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1,
+                           no_dht=True) as (server_endpoint, dht_endpoint):
         expert1 = RemoteExpert('expert.0', server_endpoint)
         expert2 = RemoteExpert('expert.1', server_endpoint)
         model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))