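"""Tests for expert checkpointing: saving and loading expert state, restoring the
update counter, and resuming the learning-rate schedule after a checkpoint load."""
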
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest
import torch
from torch.nn import Linear

from hivemind import BatchTensorDescriptor, ExpertBackend
from hivemind.server.checkpoints import store_experts, load_experts
from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup

EXPERT_WEIGHT_UPDATES = 3
BACKWARD_PASSES_BEFORE_SAVE = 2
BACKWARD_PASSES_AFTER_SAVE = 2
EXPERT_NAME = 'test_expert'
PEAK_LR = 1.0

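
# A minimal single-expert setup: one Linear(1, 1) module wrapped in an ExpertBackend
# whose LR schedule warms up for BACKWARD_PASSES_BEFORE_SAVE steps and then decays
# for BACKWARD_PASSES_AFTER_SAVE steps, so checkpoint round-trips are easy to verify.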
@pytest.fixture
def example_experts():
    expert = Linear(1, 1)
    opt = torch.optim.SGD(expert.parameters(), PEAK_LR)

    args_schema = (BatchTensorDescriptor(1),)
    expert_backend = ExpertBackend(name=EXPERT_NAME, expert=expert, optimizer=opt,
                                   scheduler=get_linear_schedule_with_warmup,
                                   num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
                                   num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
                                   args_schema=args_schema, outputs_schema=BatchTensorDescriptor(1),
                                   max_batch_size=1)
    experts = {EXPERT_NAME: expert_backend}
    yield experts

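
# Checkpoint round-trip: store a checkpoint after each weight update, zero out the
# weight, then load and check that the most recently stored value comes back.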
@pytest.mark.forked
def test_save_load_checkpoints(example_experts):
    expert = example_experts[EXPERT_NAME].expert

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
            expert.weight.data[0] = i
            store_experts(example_experts, tmp_path)

        checkpoints_dir = tmp_path / EXPERT_NAME
        assert checkpoints_dir.exists()
        # one checkpoint file per store_experts call, plus checkpoint_last.pt
        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1

        expert.weight.data[0] = 0
        load_experts(example_experts, tmp_path)
        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES

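
# load_experts must restore the backend's update counter: backward passes made
# after store_experts should be forgotten once the checkpoint is loaded.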
@pytest.mark.forked
def test_restore_update_count(example_experts):
    expert_backend = example_experts[EXPERT_NAME]

    batch = torch.randn(1, 1)
    loss_grad = torch.randn(1, 1)

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
            expert_backend.backward(batch, loss_grad)

        store_experts(example_experts, tmp_path)

        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
            expert_backend.backward(batch, loss_grad)

        load_experts(example_experts, tmp_path)
        assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE

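
# The schedule ramps the LR linearly from 0 to PEAK_LR over the warmup steps, then
# decays it linearly back to 0. Loading a checkpoint should also restore the
# scheduler state, bringing the LR back to its value at save time (PEAK_LR).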
@pytest.mark.forked
def test_lr_schedule(example_experts):
    expert_backend = example_experts[EXPERT_NAME]
    optimizer = expert_backend.optimizer

    batch = torch.randn(1, 1)
    loss_grad = torch.randn(1, 1)

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)
        assert optimizer.param_groups[0]['lr'] == 0.0

        for i in range(BACKWARD_PASSES_BEFORE_SAVE):
            assert optimizer.param_groups[0]['lr'] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
            expert_backend.backward(batch, loss_grad)

        assert optimizer.param_groups[0]['lr'] == PEAK_LR
        store_experts(example_experts, tmp_path)

        for i in range(BACKWARD_PASSES_AFTER_SAVE):
            assert optimizer.param_groups[0]['lr'] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
            expert_backend.backward(batch, loss_grad)

        assert optimizer.param_groups[0]['lr'] == 0.0
        load_experts(example_experts, tmp_path)
        assert optimizer.param_groups[0]['lr'] == PEAK_LR