
Load checkpoints on server start (#138)

* Load checkpoints on server start
Max Ryabinin, 4 years ago
parent commit d092810322

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.23'
+__version__ = '0.8.24'

+ 23 - 6
hivemind/server/__init__.py

@@ -7,12 +7,13 @@ import threading
 from contextlib import contextmanager
 from functools import partial
 from typing import Dict, Optional, Tuple, List
+from pathlib import Path
 
 import torch
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.checkpoint_saver import CheckpointSaver
+from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread
 from hivemind.server.expert_backend import ExpertBackend
@@ -69,8 +70,8 @@ class Server(threading.Thread):
     @staticmethod
     def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None,
-               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -91,6 +92,9 @@ class Server(threading.Thread):
         :param dht_port:  DHT node will listen on this port, default = find open port
            You can then use this node as initial peer for subsequent servers.
 
+        :param checkpoint_dir: directory to save expert checkpoints
+        :param load_experts: whether to load expert checkpoints from checkpoint_dir
+
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
             for each BatchTensorProto in ExpertBackend for the respective experts.
@@ -100,8 +104,6 @@ class Server(threading.Thread):
         if len(kwargs) != 0:
             logger.info("Ignored kwargs:", kwargs)
         assert expert_cls in name_to_block
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
 
         if no_dht:
             dht = None
@@ -110,7 +112,19 @@ class Server(threading.Thread):
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
             logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
-        # get expert uids
+        if load_experts:
+            assert dir_is_correct(checkpoint_dir)
+            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
+            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
+            if expert_uids:
+                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
+            else:
+                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")
+
+        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
+            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
+
+        # get expert uids if not loaded previously
         if expert_uids is None:
             assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
             logger.info(f"Generating expert uids from pattern {expert_pattern}")
@@ -138,6 +152,9 @@ class Server(threading.Thread):
                                                          opt=optim_cls(expert.parameters()),
                                                          max_batch_size=max_batch_size)
 
+        if load_experts:
+            load_weights(experts, checkpoint_dir)
+
         server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
                         start=start)
         return server
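
A minimal usage sketch of the new arguments (illustrative only: the ./checkpoints path is hypothetical and assumed to have been populated by a CheckpointSaver in a previous run; Server is re-exported at the package level via hivemind/__init__.py):

    from pathlib import Path
    import hivemind

    # Restart a server from saved checkpoints: expert UIDs are taken from the
    # subdirectory names in ./checkpoints, and each expert is restored from its
    # checkpoint_last.pt before the server starts serving.
    server = hivemind.Server.create(expert_cls='ffn', hidden_dim=1024, no_dht=True,
                                    checkpoint_dir=Path('./checkpoints'),
                                    load_experts=True, start=True)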

+ 22 - 6
hivemind/server/checkpoint_saver.py → hivemind/server/checkpoints.py

@@ -4,37 +4,53 @@ from pathlib import Path
 from shutil import copytree
 from tempfile import TemporaryDirectory
 from typing import Dict
+import os
 
 import torch
 
 from hivemind.server.expert_backend import ExpertBackend
 
 
+def dir_is_correct(directory: Path):
+    assert directory is not None
+    assert directory.exists()
+    assert directory.is_dir()
+    return True
+
+
 class CheckpointSaver(threading.Thread):
     def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
         super().__init__()
+        assert dir_is_correct(checkpoint_dir)
         self.expert_backends = expert_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
         self.stop = threading.Event()
 
+        # create expert directories to ensure that the directory is writable and checkpoints can be loaded
+        store_experts(self.expert_backends, self.checkpoint_dir)
+
     def run(self) -> None:
         while not self.stop.wait(self.update_period):
             store_experts(self.expert_backends, self.checkpoint_dir)
 
 
-def store_experts(experts: Dict[str, ExpertBackend], checkpoints_dir: Path):
+def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+    assert dir_is_correct(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep='_')
     with TemporaryDirectory() as tmpdirname:
         for expert_name, expert_backend in experts.items():
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
-            torch.save(expert_backend.state_dict(), expert_dir / f'checkpoint_{timestamp}.pt')
-        copytree(tmpdirname, str(checkpoints_dir), dirs_exist_ok=True)
+            checkpoint_name = expert_dir / f'checkpoint_{timestamp}.pt'
+            torch.save(expert_backend.state_dict(), checkpoint_name)
+            os.symlink(checkpoint_name, expert_dir / 'checkpoint_last.pt')
+        copytree(tmpdirname, str(checkpoint_dir), dirs_exist_ok=True)
 
 
-def load_weights(experts: Dict[str, ExpertBackend], checkpoints_dir: Path):
+def load_weights(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+    assert dir_is_correct(checkpoint_dir)
     for expert_name, expert in experts.items():
-        checkpoints_folder = checkpoints_dir / expert_name
-        latest_checkpoint = max(checkpoints_folder.glob('checkpoint_*.pt'))
+        checkpoints_folder = checkpoint_dir / expert_name
+        latest_checkpoint = checkpoints_folder / 'checkpoint_last.pt'
         expert.load_state_dict(torch.load(latest_checkpoint))
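
A short sketch of how these helpers fit together (the directory name and expert UID are illustrative; experts is a Dict[str, ExpertBackend] such as the one built in Server.create above):

    from pathlib import Path
    from hivemind.server.checkpoints import store_experts, load_weights

    checkpoint_dir = Path('./checkpoints')  # must already exist: dir_is_correct() asserts this
    store_experts(experts, checkpoint_dir)
    # Each expert gets its own subdirectory with timestamped checkpoints, e.g.
    #   checkpoints/expert.0/checkpoint_2021-02-15_12:00:00.000000.pt
    #   checkpoints/expert.0/checkpoint_last.pt
    # checkpoint_last.pt is written as a symlink inside the temporary directory;
    # shutil.copytree dereferences symlinks by default, so the copy that lands in
    # checkpoint_dir is a regular file holding the most recent state dict.
    load_weights(experts, checkpoint_dir)   # restores every expert from its checkpoint_last.pt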

+ 3 - 0
hivemind/server/expert_backend.py

@@ -61,6 +61,8 @@ class ExpertBackend(nn.Module):
         self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
         self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
 
+        self.register_buffer('update_count', torch.zeros(1, dtype=torch.long))
+
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
@@ -130,6 +132,7 @@ class ExpertBackend(nn.Module):
         """
         self.opt.step()
         self.opt.zero_grad()
+        self.update_count += 1
 
     def get_info(self) -> Dict[str, Any]:
         """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """

+ 4 - 0
scripts/run_server.py

@@ -1,4 +1,5 @@
 from functools import partial
+from pathlib import Path
 
 import configargparse
 import torch
@@ -43,6 +44,9 @@ def main():
                              'a server can spawn before hitting "Too many open files"; Use at your own risk.')
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
                         'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
+    parser.add_argument('--checkpoint_dir', type=Path, required=False, help='Directory to store expert checkpoints')
+    parser.add_argument('--load_experts', action='store_true', help='Load experts from the checkpoint directory')
+
     # fmt:on
     args = vars(parser.parse_args())
     args.pop('config', None)
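
With these two flags, restarting a server from a previous run might be invoked as
python scripts/run_server.py --checkpoint_dir ./checkpoints --load_experts
(the ./checkpoints path is illustrative); the remaining flags come from the rest of the parser and are not shown in this hunk.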

+ 46 - 10
tests/test_checkpoints.py

@@ -1,13 +1,19 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
+import pytest
 import torch
 from torch.nn import Linear
 
 from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.server.checkpoint_saver import store_experts, load_weights
+from hivemind.server.checkpoints import store_experts, load_weights
 
+EXPERT_WEIGHT_UPDATES = 3
+BACKWARD_PASSES_BEFORE_SAVE = 2
+BACKWARD_PASSES_AFTER_SAVE = 2
 
+
+@pytest.mark.forked
 def test_save_load_checkpoints():
     experts = {}
     expert = Linear(1, 1)
@@ -22,19 +28,49 @@ def test_save_load_checkpoints():
     with TemporaryDirectory() as tmpdir:
         tmp_path = Path(tmpdir)
 
-        expert.weight.data[0] = 1
-        store_experts(experts, tmp_path)
-        expert.weight.data[0] = 2
-        store_experts(experts, tmp_path)
-        expert.weight.data[0] = 3
-        store_experts(experts, tmp_path)
+        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
+            expert.weight.data[0] = i
+            store_experts(experts, tmp_path)
 
         checkpoints_dir = tmp_path / expert_name
 
         assert checkpoints_dir.exists()
-        assert len(list(checkpoints_dir.iterdir())) == 3
+        # include checkpoint_last.pt
+        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
+
+        expert.weight.data[0] = 0
+
+        load_weights(experts, tmp_path)
+        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
+
+
+@pytest.mark.forked
+def test_restore_update_count():
+    experts = {}
+    expert = Linear(1, 1)
+    opt = torch.optim.SGD(expert.parameters(), 0.0)
+    expert_name = f'test_expert'
+    args_schema = (BatchTensorDescriptor(1),)
+    expert_backend = ExpertBackend(name=expert_name, expert=expert, opt=opt,
+                                   args_schema=args_schema,
+                                   outputs_schema=BatchTensorDescriptor(1),
+                                   max_batch_size=1,
+                                   )
+    experts[expert_name] = expert_backend
+
+    batch = torch.randn(1, 1)
+    loss_grad = torch.randn(1, 1)
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+
+        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
+            expert_backend.backward(batch, loss_grad)
+
+        store_experts(experts, tmp_path)
 
-        expert.weight.data[0] = 4
+        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
+            expert_backend.backward(batch, loss_grad)
 
         load_weights(experts, tmp_path)
-        assert expert.weight.data[0] == 3
+        assert experts[expert_name].update_count == BACKWARD_PASSES_BEFORE_SAVE
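
Both tests are marked with pytest.mark.forked, which runs each of them in a forked subprocess and requires the pytest-forked plugin; a plain pytest run over tests/test_checkpoints.py picks them up as usual.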