test_expert_backend.py

from pathlib import Path
from tempfile import TemporaryDirectory

import pytest
import torch
from torch.nn import Linear

from hivemind import BatchTensorDescriptor, ModuleBackend
from hivemind.moe.server.checkpoints import load_experts, store_experts
from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup

EXPERT_WEIGHT_UPDATES = 3
BACKWARD_PASSES_BEFORE_SAVE = 2
BACKWARD_PASSES_AFTER_SAVE = 2
EXPERT_NAME = "test_expert"
PEAK_LR = 1.0

@pytest.fixture
def example_experts():
    expert = Linear(1, 1)
    opt = torch.optim.SGD(expert.parameters(), PEAK_LR)

    args_schema = (BatchTensorDescriptor(1),)
    expert_backend = ModuleBackend(
        name=EXPERT_NAME,
        module=expert,
        optimizer=opt,
        scheduler=get_linear_schedule_with_warmup(
            opt,
            num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
            num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
        ),
        args_schema=args_schema,
        outputs_schema=BatchTensorDescriptor(1),
        max_batch_size=1,
    )
    experts = {EXPERT_NAME: expert_backend}
    yield experts

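# Note: as exercised by the tests below, ModuleBackend.backward(...) both runs
# the backward pass and applies an optimizer step plus a scheduler step, so
# each call advances the expert weights and the learning-rate schedule by one
# update.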
@pytest.mark.forked
def test_save_load_checkpoints(example_experts):
    expert = example_experts[EXPERT_NAME].module

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
            expert.weight.data[0] = i
            store_experts(example_experts, tmp_path)

        checkpoints_dir = tmp_path / EXPERT_NAME
        assert checkpoints_dir.exists()
        # include checkpoint_last.pt
        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1

        expert.weight.data[0] = 0
        load_experts(example_experts, tmp_path)
        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES

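# The assertion below expects BACKWARD_PASSES_BEFORE_SAVE + 1 because PyTorch
# LR schedulers perform an implicit initial step() during __init__, so
# _step_count starts at 1 rather than 0.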
@pytest.mark.forked
def test_restore_update_count(example_experts):
    expert_backend = example_experts[EXPERT_NAME]

    batch = torch.randn(1, 1)
    loss_grad = torch.randn(1, 1)

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
            expert_backend.backward(batch, loss_grad)

        store_experts(example_experts, tmp_path)

        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
            expert_backend.backward(batch, loss_grad)

        load_experts(example_experts, tmp_path)
        assert expert_backend.scheduler._step_count == BACKWARD_PASSES_BEFORE_SAVE + 1

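# With the fixture constants (2 warmup steps, 4 total training steps,
# PEAK_LR = 1.0), the expected trajectory is 0.0 -> 0.5 -> 1.0 -> 0.5 -> 0.0:
# linear warmup to the peak, then linear decay to zero.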
@pytest.mark.forked
def test_lr_schedule(example_experts):
    expert_backend = example_experts[EXPERT_NAME]
    optimizer = expert_backend.optimizer

    batch = torch.randn(1, 1)
    loss_grad = torch.randn(1, 1)

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)
        assert optimizer.param_groups[0]["lr"] == 0.0

        for i in range(BACKWARD_PASSES_BEFORE_SAVE):
            assert optimizer.param_groups[0]["lr"] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
            expert_backend.backward(batch, loss_grad)

        assert optimizer.param_groups[0]["lr"] == PEAK_LR
        store_experts(example_experts, tmp_path)

        for i in range(BACKWARD_PASSES_AFTER_SAVE):
            assert optimizer.param_groups[0]["lr"] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
            expert_backend.backward(batch, loss_grad)

        assert optimizer.param_groups[0]["lr"] == 0.0

        load_experts(example_experts, tmp_path)
        assert optimizer.param_groups[0]["lr"] == PEAK_LR
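

# For reference, a minimal sketch of the linear warmup-then-decay multiplier
# that the assertions in test_lr_schedule imply. The closed form below is an
# assumption consistent with those assertions (mirroring the transformers-style
# get_linear_schedule_with_warmup), not a copy of hivemind's implementation,
# and _expected_lr_multiplier is a hypothetical helper name:
def _expected_lr_multiplier(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    if step < num_warmup_steps:
        # linear warmup: 0 at step 0, reaching 1.0 at num_warmup_steps
        return step / max(1, num_warmup_steps)
    # linear decay: 1.0 at the end of warmup, reaching 0.0 at num_training_steps
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))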