
Load checkpoints on server start (#138)

* Load checkpoints on server start
Max Ryabinin · 4 years ago · commit d092810322

hivemind/__init__.py (+1, -1)

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.23'
+__version__ = '0.8.24'

hivemind/server/__init__.py (+23, -6)

@@ -7,12 +7,13 @@ import threading
 from contextlib import contextmanager
 from functools import partial
 from typing import Dict, Optional, Tuple, List
+from pathlib import Path
 
 import torch
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.checkpoint_saver import CheckpointSaver
+from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread
 from hivemind.server.expert_backend import ExpertBackend
@@ -69,8 +70,8 @@ class Server(threading.Thread):
     @staticmethod
     def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None,
-               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -91,6 +92,9 @@ class Server(threading.Thread):
         :param dht_port:  DHT node will listen on this port, default = find open port
            You can then use this node as initial peer for subsequent servers.
 
+        :param checkpoint_dir: directory to save expert checkpoints
+        :param load_experts: whether to load expert checkpoints from checkpoint_dir
+
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
             for each BatchTensorProto in ExpertBackend for the respective experts.
@@ -100,8 +104,6 @@ class Server(threading.Thread):
         if len(kwargs) != 0:
             logger.info("Ignored kwargs:", kwargs)
         assert expert_cls in name_to_block
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
 
         if no_dht:
             dht = None
@@ -110,7 +112,19 @@ class Server(threading.Thread):
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
             logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
-        # get expert uids
+        if load_experts:
+            assert dir_is_correct(checkpoint_dir)
+            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
+            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
+            if expert_uids:
+                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
+            else:
+                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")
+
+        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
+            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
+
+        # get expert uids if not loaded previously
         if expert_uids is None:
             assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
             logger.info(f"Generating expert uids from pattern {expert_pattern}")
@@ -138,6 +152,9 @@ class Server(threading.Thread):
                                                          opt=optim_cls(expert.parameters()),
                                                          max_batch_size=max_batch_size)
 
+        if load_experts:
+            load_weights(experts, checkpoint_dir)
+
         server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
                         start=start)
         return server
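Taken together, these changes let a server rebuild its expert set from a checkpoint directory instead of generating expert UIDs. A minimal usage sketch (the path is hypothetical; it assumes the directory already exists and contains one subdirectory per expert with a checkpoint_last.pt, as produced by CheckpointSaver):

```python
from pathlib import Path

import torch
from hivemind.server import Server

# Hypothetical resume scenario: './checkpoints' was populated by an earlier run.
server = Server.create(
    listen_on='0.0.0.0:*',
    expert_cls='ffn',
    hidden_dim=1024,                       # should match the experts that were checkpointed
    optim_cls=torch.optim.Adam,
    checkpoint_dir=Path('./checkpoints'),  # must exist and be a directory (dir_is_correct)
    load_experts=True,                     # expert UIDs are taken from the checkpoint subdirectories
    start=True,
)
```

With load_experts=True, num_experts and expert_pattern are not needed, since the UID list comes from the checkpoint directory itself.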

hivemind/server/checkpoint_saver.py → hivemind/server/checkpoints.py (+22, -6)

@@ -4,37 +4,53 @@ from pathlib import Path
 from shutil import copytree
 from tempfile import TemporaryDirectory
 from typing import Dict
+import os
 
 import torch
 
 from hivemind.server.expert_backend import ExpertBackend
 
 
+def dir_is_correct(directory: Path):
+    assert directory is not None
+    assert directory.exists()
+    assert directory.is_dir()
+    return True
+
+
 class CheckpointSaver(threading.Thread):
     def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
         super().__init__()
+        assert dir_is_correct(checkpoint_dir)
         self.expert_backends = expert_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
         self.stop = threading.Event()
 
+        # create expert directories to ensure that the directory is writable and checkpoints can be loaded
+        store_experts(self.expert_backends, self.checkpoint_dir)
+
     def run(self) -> None:
         while not self.stop.wait(self.update_period):
             store_experts(self.expert_backends, self.checkpoint_dir)
 
 
-def store_experts(experts: Dict[str, ExpertBackend], checkpoints_dir: Path):
+def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+    assert dir_is_correct(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep='_')
     with TemporaryDirectory() as tmpdirname:
         for expert_name, expert_backend in experts.items():
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
-            torch.save(expert_backend.state_dict(), expert_dir / f'checkpoint_{timestamp}.pt')
-        copytree(tmpdirname, str(checkpoints_dir), dirs_exist_ok=True)
+            checkpoint_name = expert_dir / f'checkpoint_{timestamp}.pt'
+            torch.save(expert_backend.state_dict(), checkpoint_name)
+            os.symlink(checkpoint_name, expert_dir / 'checkpoint_last.pt')
+        copytree(tmpdirname, str(checkpoint_dir), dirs_exist_ok=True)
 
 
-def load_weights(experts: Dict[str, ExpertBackend], checkpoints_dir: Path):
+def load_weights(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+    assert dir_is_correct(checkpoint_dir)
     for expert_name, expert in experts.items():
-        checkpoints_folder = checkpoints_dir / expert_name
-        latest_checkpoint = max(checkpoints_folder.glob('checkpoint_*.pt'))
+        checkpoints_folder = checkpoint_dir / expert_name
+        latest_checkpoint = checkpoints_folder / 'checkpoint_last.pt'
         expert.load_state_dict(torch.load(latest_checkpoint))
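For orientation: store_experts now writes one subdirectory per expert containing timestamped checkpoint_*.pt files plus a checkpoint_last.pt symlink to the newest one, and load_weights simply follows that symlink instead of sorting timestamps. A usage sketch mirroring tests/test_checkpoints.py (the checkpoint path is hypothetical and must already exist, per dir_is_correct):

```python
from pathlib import Path

import torch
from torch.nn import Linear

from hivemind import BatchTensorDescriptor, ExpertBackend
from hivemind.server.checkpoints import store_experts, load_weights

# Build a single-expert backend, as in the test below.
expert = Linear(1, 1)
experts = {'test_expert': ExpertBackend(name='test_expert', expert=expert,
                                        opt=torch.optim.SGD(expert.parameters(), 0.0),
                                        args_schema=(BatchTensorDescriptor(1),),
                                        outputs_schema=BatchTensorDescriptor(1),
                                        max_batch_size=1)}

checkpoint_dir = Path('checkpoints')    # hypothetical; create it before calling the helpers
store_experts(experts, checkpoint_dir)  # writes checkpoints/test_expert/checkpoint_<timestamp>.pt + checkpoint_last.pt
load_weights(experts, checkpoint_dir)   # restores each expert from its checkpoint_last.pt
```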

hivemind/server/expert_backend.py (+3, -0)

@@ -61,6 +61,8 @@ class ExpertBackend(nn.Module):
         self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
         self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
 
+        self.register_buffer('update_count', torch.zeros(1, dtype=torch.long))
+
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
@@ -130,6 +132,7 @@ class ExpertBackend(nn.Module):
         """
         """
         self.opt.step()
         self.opt.step()
         self.opt.zero_grad()
         self.opt.zero_grad()
+        self.update_count += 1
 
 
     def get_info(self) -> Dict[str, Any]:
     def get_info(self) -> Dict[str, Any]:
         """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """
         """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """

scripts/run_server.py (+4, -0)

@@ -1,4 +1,5 @@
 from functools import partial
+from pathlib import Path
 
 import configargparse
 import torch
@@ -43,6 +44,9 @@ def main():
                              'a server can spawn before hitting "Too many open files"; Use at your own risk.')
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
                         'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
+    parser.add_argument('--checkpoint_dir', type=Path, required=False, help='Directory to store expert checkpoints')
+    parser.add_argument('--load_experts', action='store_true', help='Load experts from the checkpoint directory')
+
     # fmt:on
     args = vars(parser.parse_args())
     args.pop('config', None)
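With these flags, a server started via scripts/run_server.py can be pointed at a checkpoint directory and told to restore its experts from it, e.g. `python scripts/run_server.py --checkpoint_dir checkpoints --load_experts`. The remaining arguments (expert configuration, networking, etc.) are omitted here; they depend on the rest of run_server.py, which this diff does not show.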

tests/test_checkpoints.py (+46, -10)

@@ -1,13 +1,19 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
+import pytest
 import torch
 from torch.nn import Linear
 
 from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.server.checkpoint_saver import store_experts, load_weights
+from hivemind.server.checkpoints import store_experts, load_weights
 
+EXPERT_WEIGHT_UPDATES = 3
+BACKWARD_PASSES_BEFORE_SAVE = 2
+BACKWARD_PASSES_AFTER_SAVE = 2
 
+
+@pytest.mark.forked
 def test_save_load_checkpoints():
     experts = {}
     expert = Linear(1, 1)
@@ -22,19 +28,49 @@ def test_save_load_checkpoints():
     with TemporaryDirectory() as tmpdir:
         tmp_path = Path(tmpdir)
 
-        expert.weight.data[0] = 1
-        store_experts(experts, tmp_path)
-        expert.weight.data[0] = 2
-        store_experts(experts, tmp_path)
-        expert.weight.data[0] = 3
-        store_experts(experts, tmp_path)
+        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
+            expert.weight.data[0] = i
+            store_experts(experts, tmp_path)
 
         checkpoints_dir = tmp_path / expert_name
 
         assert checkpoints_dir.exists()
-        assert len(list(checkpoints_dir.iterdir())) == 3
+        # include checkpoint_last.pt
+        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
+
+        expert.weight.data[0] = 0
+
+        load_weights(experts, tmp_path)
+        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
+
+
+@pytest.mark.forked
+def test_restore_update_count():
+    experts = {}
+    expert = Linear(1, 1)
+    opt = torch.optim.SGD(expert.parameters(), 0.0)
+    expert_name = f'test_expert'
+    args_schema = (BatchTensorDescriptor(1),)
+    expert_backend = ExpertBackend(name=expert_name, expert=expert, opt=opt,
+                                   args_schema=args_schema,
+                                   outputs_schema=BatchTensorDescriptor(1),
+                                   max_batch_size=1,
+                                   )
+    experts[expert_name] = expert_backend
+
+    batch = torch.randn(1, 1)
+    loss_grad = torch.randn(1, 1)
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_path = Path(tmpdir)
+
+        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
+            expert_backend.backward(batch, loss_grad)
+
+        store_experts(experts, tmp_path)
 
-        expert.weight.data[0] = 4
+        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
+            expert_backend.backward(batch, loss_grad)
 
         load_weights(experts, tmp_path)
-        assert expert.weight.data[0] == 3
+        assert experts[expert_name].update_count == BACKWARD_PASSES_BEFORE_SAVE