test_checkpoints.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from pathlib import Path
  2. from tempfile import TemporaryDirectory
  3. import torch
  4. from torch.nn import Linear
  5. from hivemind import BatchTensorDescriptor, ExpertBackend
  6. from hivemind.server.checkpoint_saver import store_experts, load_weights
  7. def test_save_load_checkpoints():
  8. experts = {}
  9. expert = Linear(1, 1)
  10. opt = torch.optim.SGD(expert.parameters(), 0.0)
  11. expert_name = f'test_expert'
  12. args_schema = (BatchTensorDescriptor(1),)
  13. experts[expert_name] = ExpertBackend(name=expert_name, expert=expert, opt=opt,
  14. args_schema=args_schema,
  15. outputs_schema=BatchTensorDescriptor(1),
  16. max_batch_size=1,
  17. )
  18. with TemporaryDirectory() as tmpdir:
  19. tmp_path = Path(tmpdir)
  20. expert.weight.data[0] = 1
  21. store_experts(experts, tmp_path)
  22. expert.weight.data[0] = 2
  23. store_experts(experts, tmp_path)
  24. expert.weight.data[0] = 3
  25. store_experts(experts, tmp_path)
  26. checkpoints_dir = tmp_path / expert_name
  27. assert checkpoints_dir.exists()
  28. assert len(list(checkpoints_dir.iterdir())) == 3
  29. expert.weight.data[0] = 4
  30. load_weights(experts, tmp_path)
  31. assert expert.weight.data[0] == 3