# test_expert_backend.py — checkpoint save/load and LR-schedule tests for hivemind's ExpertBackend
  1. from pathlib import Path
  2. from tempfile import TemporaryDirectory
  3. import pytest
  4. import torch
  5. from torch.nn import Linear
  6. from hivemind import BatchTensorDescriptor, ExpertBackend
  7. from hivemind.moe.server.checkpoints import store_experts, load_experts
  8. from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
# How many distinct weight values are written (and checkpointed) before testing reload.
EXPERT_WEIGHT_UPDATES = 3
# Backward passes performed before/after the checkpoint is stored; these also set the
# scheduler's warmup length and total step count in the example_experts fixture.
BACKWARD_PASSES_BEFORE_SAVE = 2
BACKWARD_PASSES_AFTER_SAVE = 2
# Name under which the single test expert is registered in the experts dict.
EXPERT_NAME = "test_expert"
# Learning rate reached at the top of the linear warmup schedule.
PEAK_LR = 1.0
  14. @pytest.fixture
  15. def example_experts():
  16. expert = Linear(1, 1)
  17. opt = torch.optim.SGD(expert.parameters(), PEAK_LR)
  18. args_schema = (BatchTensorDescriptor(1),)
  19. expert_backend = ExpertBackend(
  20. name=EXPERT_NAME,
  21. expert=expert,
  22. optimizer=opt,
  23. scheduler=get_linear_schedule_with_warmup,
  24. num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
  25. num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
  26. args_schema=args_schema,
  27. outputs_schema=BatchTensorDescriptor(1),
  28. max_batch_size=1,
  29. )
  30. experts = {EXPERT_NAME: expert_backend}
  31. yield experts
  32. @pytest.mark.forked
  33. def test_save_load_checkpoints(example_experts):
  34. expert = example_experts[EXPERT_NAME].expert
  35. with TemporaryDirectory() as tmpdir:
  36. tmp_path = Path(tmpdir)
  37. for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
  38. expert.weight.data[0] = i
  39. store_experts(example_experts, tmp_path)
  40. checkpoints_dir = tmp_path / EXPERT_NAME
  41. assert checkpoints_dir.exists()
  42. # include checkpoint_last.pt
  43. assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
  44. expert.weight.data[0] = 0
  45. load_experts(example_experts, tmp_path)
  46. assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
  47. @pytest.mark.forked
  48. def test_restore_update_count(example_experts):
  49. expert_backend = example_experts[EXPERT_NAME]
  50. batch = torch.randn(1, 1)
  51. loss_grad = torch.randn(1, 1)
  52. with TemporaryDirectory() as tmpdir:
  53. tmp_path = Path(tmpdir)
  54. for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
  55. expert_backend.backward(batch, loss_grad)
  56. store_experts(example_experts, tmp_path)
  57. for _ in range(BACKWARD_PASSES_AFTER_SAVE):
  58. expert_backend.backward(batch, loss_grad)
  59. load_experts(example_experts, tmp_path)
  60. assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE
  61. @pytest.mark.forked
  62. def test_lr_schedule(example_experts):
  63. expert_backend = example_experts[EXPERT_NAME]
  64. optimizer = expert_backend.optimizer
  65. batch = torch.randn(1, 1)
  66. loss_grad = torch.randn(1, 1)
  67. with TemporaryDirectory() as tmpdir:
  68. tmp_path = Path(tmpdir)
  69. assert optimizer.param_groups[0]["lr"] == 0.0
  70. for i in range(BACKWARD_PASSES_BEFORE_SAVE):
  71. assert optimizer.param_groups[0]["lr"] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
  72. expert_backend.backward(batch, loss_grad)
  73. assert optimizer.param_groups[0]["lr"] == PEAK_LR
  74. store_experts(example_experts, tmp_path)
  75. for i in range(BACKWARD_PASSES_AFTER_SAVE):
  76. assert optimizer.param_groups[0]["lr"] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
  77. expert_backend.backward(batch, loss_grad)
  78. assert optimizer.param_groups[0]["lr"] == 0.0
  79. load_experts(example_experts, tmp_path)
  80. assert optimizer.param_groups[0]["lr"] == PEAK_LR