"""Utilities for periodically saving and restoring checkpoints of expert backends."""
import os
import threading
from datetime import datetime
from pathlib import Path
from shutil import copy2
from tempfile import TemporaryDirectory
from typing import Dict

import torch

from hivemind.server.expert_backend import ExpertBackend
  10. def dir_is_correct(directory: Path):
  11. assert directory is not None
  12. assert directory.exists()
  13. assert directory.is_dir()
  14. return True
  15. def copy_tree(src: str, dst: str):
  16. if not os.path.exists(dst):
  17. os.makedirs(dst)
  18. for item in os.listdir(src):
  19. src_entry = os.path.join(src, item)
  20. dst_entry = os.path.join(dst, item)
  21. if os.path.isdir(src_entry):
  22. copy_tree(src_entry, dst_entry)
  23. else:
  24. copy2(src_entry, dst_entry)
  25. class CheckpointSaver(threading.Thread):
  26. def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
  27. super().__init__()
  28. assert dir_is_correct(checkpoint_dir)
  29. self.expert_backends = expert_backends
  30. self.update_period = update_period
  31. self.checkpoint_dir = checkpoint_dir
  32. self.stop = threading.Event()
  33. # create expert directories to ensure that the directory is writable and checkpoints can be loaded
  34. store_experts(self.expert_backends, self.checkpoint_dir)
  35. def run(self) -> None:
  36. while not self.stop.wait(self.update_period):
  37. store_experts(self.expert_backends, self.checkpoint_dir)
  38. def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
  39. assert dir_is_correct(checkpoint_dir)
  40. timestamp = datetime.now().isoformat(sep='_')
  41. with TemporaryDirectory() as tmpdirname:
  42. for expert_name, expert_backend in experts.items():
  43. expert_dir = Path(tmpdirname) / expert_name
  44. expert_dir.mkdir()
  45. checkpoint_name = expert_dir / f'checkpoint_{timestamp}.pt'
  46. torch.save(expert_backend.state_dict(), checkpoint_name)
  47. os.symlink(checkpoint_name, expert_dir / 'checkpoint_last.pt')
  48. copy_tree(tmpdirname, str(checkpoint_dir))
  49. def load_weights(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
  50. assert dir_is_correct(checkpoint_dir)
  51. for expert_name, expert in experts.items():
  52. checkpoints_folder = checkpoint_dir / expert_name
  53. latest_checkpoint = checkpoints_folder / 'checkpoint_last.pt'
  54. expert.load_state_dict(torch.load(latest_checkpoint))