# test_expert_backend.py — tests for ExpertBackend checkpoint saving/loading and LR scheduling
  1. from pathlib import Path
  2. from tempfile import TemporaryDirectory
  3. import pytest
  4. import torch
  5. from torch.nn import Linear
  6. from hivemind import BatchTensorDescriptor, ExpertBackend
  7. from hivemind.server.checkpoints import store_experts, load_experts
  8. from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
# Number of distinct weight values written (and checkpoints stored) in the save/load test.
EXPERT_WEIGHT_UPDATES = 3
# Backward passes performed before the checkpoint is stored — also used as the
# scheduler's warmup length, so the LR peaks exactly at save time.
BACKWARD_PASSES_BEFORE_SAVE = 2
# Backward passes performed after the checkpoint is stored — the scheduler's
# decay phase, so the LR reaches zero at the end of training.
BACKWARD_PASSES_AFTER_SAVE = 2
# Key under which the single test expert is registered in the experts dict.
EXPERT_NAME = 'test_expert'
# Maximum learning rate reached by the linear warmup schedule.
PEAK_LR = 1.0
  14. @pytest.fixture
  15. def example_experts():
  16. expert = Linear(1, 1)
  17. opt = torch.optim.SGD(expert.parameters(), PEAK_LR)
  18. args_schema = (BatchTensorDescriptor(1),)
  19. expert_backend = ExpertBackend(name=EXPERT_NAME, expert=expert, optimizer=opt,
  20. scheduler=get_linear_schedule_with_warmup,
  21. num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
  22. num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
  23. args_schema=args_schema, outputs_schema=BatchTensorDescriptor(1), max_batch_size=1,
  24. )
  25. experts = {EXPERT_NAME: expert_backend}
  26. yield experts
  27. @pytest.mark.forked
  28. def test_save_load_checkpoints(example_experts):
  29. expert = example_experts[EXPERT_NAME].expert
  30. with TemporaryDirectory() as tmpdir:
  31. tmp_path = Path(tmpdir)
  32. for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
  33. expert.weight.data[0] = i
  34. store_experts(example_experts, tmp_path)
  35. checkpoints_dir = tmp_path / EXPERT_NAME
  36. assert checkpoints_dir.exists()
  37. # include checkpoint_last.pt
  38. assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
  39. expert.weight.data[0] = 0
  40. load_experts(example_experts, tmp_path)
  41. assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
  42. @pytest.mark.forked
  43. def test_restore_update_count(example_experts):
  44. expert_backend = example_experts[EXPERT_NAME]
  45. batch = torch.randn(1, 1)
  46. loss_grad = torch.randn(1, 1)
  47. with TemporaryDirectory() as tmpdir:
  48. tmp_path = Path(tmpdir)
  49. for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
  50. expert_backend.backward(batch, loss_grad)
  51. store_experts(example_experts, tmp_path)
  52. for _ in range(BACKWARD_PASSES_AFTER_SAVE):
  53. expert_backend.backward(batch, loss_grad)
  54. load_experts(example_experts, tmp_path)
  55. assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE
  56. @pytest.mark.forked
  57. def test_lr_schedule(example_experts):
  58. expert_backend = example_experts[EXPERT_NAME]
  59. optimizer = expert_backend.optimizer
  60. batch = torch.randn(1, 1)
  61. loss_grad = torch.randn(1, 1)
  62. with TemporaryDirectory() as tmpdir:
  63. tmp_path = Path(tmpdir)
  64. assert optimizer.param_groups[0]['lr'] == 0.0
  65. for i in range(BACKWARD_PASSES_BEFORE_SAVE):
  66. assert optimizer.param_groups[0]['lr'] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
  67. expert_backend.backward(batch, loss_grad)
  68. assert optimizer.param_groups[0]['lr'] == PEAK_LR
  69. store_experts(example_experts, tmp_path)
  70. for i in range(BACKWARD_PASSES_AFTER_SAVE):
  71. assert optimizer.param_groups[0]['lr'] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
  72. expert_backend.backward(batch, loss_grad)
  73. assert optimizer.param_groups[0]['lr'] == 0.0
  74. load_experts(example_experts, tmp_path)
  75. assert optimizer.param_groups[0]['lr'] == PEAK_LR