Prechádzať zdrojové kódy

Save checkpoints in Server (#67)

* Save checkpoints periodically

* Increase socket queue

* Graceful shutdown for threads
Max Ryabinin 5 rokov pred
rodič
commit
51f7182353

+ 15 - 3
hivemind/server/__init__.py

@@ -3,10 +3,12 @@ import os
 import threading
 from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
 from typing import Dict, Optional
+
 import torch
 
 from .connection_handler import handle_connection
 from .dht_handler import DHTHandlerThread
+from .checkpoint_saver import CheckpointSaver
 from ..dht import DHT
 from ..runtime import Runtime, ExpertBackend
 
@@ -36,11 +38,15 @@ class Server(threading.Thread):
 
     def __init__(self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
                  port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
-                 **kwargs):
+                 checkpoint_dir=None, **kwargs):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         self.addr, self.port = addr, port
         self.conn_handlers = self._create_connection_handlers(conn_handler_processes)
+        if checkpoint_dir is not None:
+            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
+        else:
+            self.checkpoint_saver = None
         self.runtime = Runtime(self.experts, **kwargs)
 
         if start:
@@ -58,6 +64,8 @@ class Server(threading.Thread):
             dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht,
                                                   addr=self.addr, port=self.port, update_period=self.update_period)
             dht_handler_thread.start()
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
 
         for process in self.conn_handlers:
             if not process.is_alive():
@@ -68,7 +76,11 @@ class Server(threading.Thread):
         for process in self.conn_handlers:
             process.join()
         if self.dht:
+            dht_handler_thread.stop = True
             dht_handler_thread.join()
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop = True
+            self.checkpoint_saver.join()
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -96,12 +108,12 @@ class Server(threading.Thread):
         sock = socket(AF_INET, SOCK_STREAM)
         sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
         sock.bind(('', self.port))
-        sock.listen()
+        sock.listen(1024)
         sock.settimeout(self.update_period)
 
         processes = [mp.context.ForkProcess(
             target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts), daemon=True)
-                     for i in range(num_handlers)]
+            for i in range(num_handlers)]
         return processes
 
     def shutdown(self):

+ 42 - 0
hivemind/server/checkpoint_saver.py

@@ -0,0 +1,42 @@
+import threading
+import time
+from datetime import datetime
+from pathlib import Path
+from shutil import copytree
+from tempfile import TemporaryDirectory
+from typing import Dict
+
+import torch
+
+from ..runtime import ExpertBackend
+
+
class CheckpointSaver(threading.Thread):
    """Background worker that periodically snapshots all experts to disk.

    The owning server requests shutdown by flipping the plain-bool ``stop``
    attribute; the flag is only inspected once per cycle, so ``join()`` may
    lag behind by up to ``update_period`` seconds.
    """

    def __init__(self, expert_backends: Dict[str, ExpertBackend], dir: Path, update_period: int):
        # NOTE(review): `dir` shadows the builtin of the same name; kept as-is
        # because the parameter name is part of the public interface.
        super().__init__()
        self.expert_backends = expert_backends
        self.dir = dir
        self.update_period = update_period
        self.stop = False  # set to True by the owner to request shutdown

    def run(self) -> None:
        """Write one checkpoint immediately, then one per update_period."""
        while True:
            if self.stop:
                break
            store_experts(self.expert_backends, self.dir)
            time.sleep(self.update_period)
+
+
def store_experts(experts: Dict[str, ExpertBackend], checkpoints_dir: Path):
    """Save one timestamped checkpoint of every expert under checkpoints_dir/<name>/.

    Checkpoints are first written into a temporary directory and then copied
    over, so a crash mid-save cannot leave a partially written file at the
    final location.
    """
    # isoformat() embeds ':' in the time component, which is an illegal
    # character in Windows file names; '-' keeps the names lexicographically
    # sortable, which load_weights relies on to find the newest checkpoint.
    timestamp = datetime.now().isoformat(sep='_').replace(':', '-')
    with TemporaryDirectory() as tmpdirname:
        for expert_name, expert_backend in experts.items():
            expert_dir = Path(tmpdirname) / expert_name
            expert_dir.mkdir()
            torch.save(expert_backend.state_dict(), expert_dir / f'checkpoint_{timestamp}.pt')
        # dirs_exist_ok requires Python >= 3.8; merges into an existing tree
        copytree(tmpdirname, str(checkpoints_dir), dirs_exist_ok=True)
+
+
def load_weights(experts: Dict[str, ExpertBackend], checkpoints_dir: Path):
    """Restore each expert from its newest checkpoint under checkpoints_dir/<name>/.

    Relies on the timestamped file names sorting lexicographically, so max()
    over the glob yields the most recent checkpoint. Propagates the ValueError
    raised by max() if an expert has no checkpoints at all.
    """
    for name, backend in experts.items():
        expert_dir = checkpoints_dir / name
        newest = max(expert_dir.glob('checkpoint_*.pt'))
        backend.load_state_dict(torch.load(newest))

+ 2 - 1
hivemind/server/dht_handler.py

@@ -13,8 +13,9 @@ class DHTHandlerThread(threading.Thread):
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
+        self.stop = False
 
     def run(self) -> None:
-        while True:
+        while not self.stop:
             self.dht.declare_experts(self.experts.keys(), self.addr, self.port)
             time.sleep(self.update_period)

+ 40 - 0
tests/test_checkpoints.py

@@ -0,0 +1,40 @@
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import torch
+from torch.nn import Linear
+
+from hivemind import BatchTensorDescriptor, ExpertBackend
+from hivemind.server.checkpoint_saver import store_experts, load_weights
+
+
def test_save_load_checkpoints():
    """Round-trip test: three saves yield three checkpoint files, and loading
    restores the weights from the most recent save."""
    expert = Linear(1, 1)
    # lr=0 so the optimizer never changes weights; it exists only because
    # ExpertBackend requires one.
    opt = torch.optim.SGD(expert.parameters(), 0.0)
    expert_name = 'test_expert'  # was f'test_expert': f-string with no placeholders (F541)
    args_schema = (BatchTensorDescriptor(1),)
    experts = {expert_name: ExpertBackend(name=expert_name, expert=expert, opt=opt,
                                          args_schema=args_schema,
                                          outputs_schema=BatchTensorDescriptor(1),
                                          max_batch_size=1,
                                          )}
    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        # one checkpoint per distinct weight value
        expert.weight.data[0] = 1
        store_experts(experts, tmp_path)
        expert.weight.data[0] = 2
        store_experts(experts, tmp_path)
        expert.weight.data[0] = 3
        store_experts(experts, tmp_path)

        checkpoints_dir = tmp_path / expert_name
        assert checkpoints_dir.exists()
        assert len(list(checkpoints_dir.iterdir())) == 3

        # clobber the weight, then restore from the newest checkpoint
        expert.weight.data[0] = 4
        load_weights(experts, tmp_path)
        assert expert.weight.data[0] == 3