test_checkpoints.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from pathlib import Path
  2. from tempfile import TemporaryDirectory
  3. import pytest
  4. import torch
  5. from torch.nn import Linear
  6. from hivemind import BatchTensorDescriptor, ExpertBackend
  7. from hivemind.server.checkpoints import store_experts, load_weights
# How many times the expert's weight is overwritten (and a checkpoint stored)
# in test_save_load_checkpoints — one checkpoint file is expected per update.
EXPERT_WEIGHT_UPDATES = 3
# Backward passes performed before vs. after the checkpoint in
# test_restore_update_count; restoring must bring update_count back to the
# pre-save value.
BACKWARD_PASSES_BEFORE_SAVE = 2
BACKWARD_PASSES_AFTER_SAVE = 2
  11. @pytest.mark.forked
  12. def test_save_load_checkpoints():
  13. experts = {}
  14. expert = Linear(1, 1)
  15. opt = torch.optim.SGD(expert.parameters(), 0.0)
  16. expert_name = f'test_expert'
  17. args_schema = (BatchTensorDescriptor(1),)
  18. experts[expert_name] = ExpertBackend(name=expert_name, expert=expert, opt=opt,
  19. args_schema=args_schema,
  20. outputs_schema=BatchTensorDescriptor(1),
  21. max_batch_size=1,
  22. )
  23. with TemporaryDirectory() as tmpdir:
  24. tmp_path = Path(tmpdir)
  25. for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
  26. expert.weight.data[0] = i
  27. store_experts(experts, tmp_path)
  28. checkpoints_dir = tmp_path / expert_name
  29. assert checkpoints_dir.exists()
  30. # include checkpoint_last.pt
  31. assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
  32. expert.weight.data[0] = 0
  33. load_weights(experts, tmp_path)
  34. assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
  35. @pytest.mark.forked
  36. def test_restore_update_count():
  37. experts = {}
  38. expert = Linear(1, 1)
  39. opt = torch.optim.SGD(expert.parameters(), 0.0)
  40. expert_name = f'test_expert'
  41. args_schema = (BatchTensorDescriptor(1),)
  42. expert_backend = ExpertBackend(name=expert_name, expert=expert, opt=opt,
  43. args_schema=args_schema,
  44. outputs_schema=BatchTensorDescriptor(1),
  45. max_batch_size=1,
  46. )
  47. experts[expert_name] = expert_backend
  48. batch = torch.randn(1, 1)
  49. loss_grad = torch.randn(1, 1)
  50. with TemporaryDirectory() as tmpdir:
  51. tmp_path = Path(tmpdir)
  52. for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
  53. expert_backend.backward(batch, loss_grad)
  54. store_experts(experts, tmp_path)
  55. for _ in range(BACKWARD_PASSES_AFTER_SAVE):
  56. expert_backend.backward(batch, loss_grad)
  57. load_weights(experts, tmp_path)
  58. assert experts[expert_name].update_count == BACKWARD_PASSES_BEFORE_SAVE