from pathlib import Path
from tempfile import TemporaryDirectory

import pytest
import torch
from torch.nn import Linear

from hivemind import BatchTensorDescriptor, ModuleBackend
from hivemind.moe.server.checkpoints import load_experts, store_experts
from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup

EXPERT_WEIGHT_UPDATES = 3
BACKWARD_PASSES_BEFORE_SAVE = 2
BACKWARD_PASSES_AFTER_SAVE = 2
EXPERT_NAME = "test_expert"
PEAK_LR = 1.0
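# With these constants, the warmup schedule ramps the learning rate from 0 to PEAK_LR
# over BACKWARD_PASSES_BEFORE_SAVE backward passes, then decays it linearly back to 0
# over BACKWARD_PASSES_AFTER_SAVE passes (this trajectory is asserted in test_lr_schedule below).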


@pytest.fixture
def example_experts():
    expert = Linear(1, 1)
    opt = torch.optim.SGD(expert.parameters(), PEAK_LR)

    args_schema = (BatchTensorDescriptor(1),)
    expert_backend = ModuleBackend(
        name=EXPERT_NAME,
        module=expert,
        optimizer=opt,
        scheduler=get_linear_schedule_with_warmup(
            opt,
            num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
            num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
        ),
        args_schema=args_schema,
        outputs_schema=BatchTensorDescriptor(1),
        max_batch_size=1,
    )
    experts = {EXPERT_NAME: expert_backend}
    yield experts


@pytest.mark.forked
def test_save_load_checkpoints(example_experts):
    expert = example_experts[EXPERT_NAME].module

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
            expert.weight.data[0] = i
            store_experts(example_experts, tmp_path)

        checkpoints_dir = tmp_path / EXPERT_NAME

        assert checkpoints_dir.exists()
        # one timestamped checkpoint per store_experts call, plus checkpoint_last.pt
        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1

        expert.weight.data[0] = 0

        load_experts(example_experts, tmp_path)

        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES


@pytest.mark.forked
def test_restore_update_count(example_experts):
    expert_backend = example_experts[EXPERT_NAME]

    batch = torch.randn(1, 1)
    loss_grad = torch.randn(1, 1)

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
            expert_backend.backward(batch, loss_grad)

        store_experts(example_experts, tmp_path)

        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
            expert_backend.backward(batch, loss_grad)

        load_experts(example_experts, tmp_path)

        assert expert_backend.scheduler._step_count == BACKWARD_PASSES_BEFORE_SAVE + 1
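        # Loading discards the post-save steps and restores the scheduler state captured at
        # save time; the extra +1 comes from the initial step() PyTorch LR schedulers perform
        # at construction.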


@pytest.mark.forked
def test_lr_schedule(example_experts):
    expert_backend = example_experts[EXPERT_NAME]
    optimizer = expert_backend.optimizer

    batch = torch.randn(1, 1)
    loss_grad = torch.randn(1, 1)

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        assert optimizer.param_groups[0]["lr"] == 0.0

        # Warmup: the learning rate rises linearly from 0 to PEAK_LR.
        for i in range(BACKWARD_PASSES_BEFORE_SAVE):
            assert optimizer.param_groups[0]["lr"] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
            expert_backend.backward(batch, loss_grad)

        assert optimizer.param_groups[0]["lr"] == PEAK_LR
        store_experts(example_experts, tmp_path)

        # Decay: after warmup, the learning rate falls linearly back to 0.
        for i in range(BACKWARD_PASSES_AFTER_SAVE):
            assert optimizer.param_groups[0]["lr"] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
            expert_backend.backward(batch, loss_grad)

        assert optimizer.param_groups[0]["lr"] == 0.0

        # Restoring the checkpoint saved right after warmup brings the LR back to its peak.
        load_experts(example_experts, tmp_path)

        assert optimizer.param_groups[0]["lr"] == PEAK_LR