test_training.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from functools import partial
  2. from typing import Optional
  3. import pytest
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from sklearn.datasets import load_digits
  8. from hivemind import RemoteExpert, background_server
  9. @pytest.mark.forked
  10. def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: float = 0.9):
  11. dataset = load_digits()
  12. X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
  13. SGD = partial(torch.optim.SGD, lr=0.05)
  14. with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64) as (server_endpoint, _):
  15. expert1 = RemoteExpert('expert.0', server_endpoint)
  16. expert2 = RemoteExpert('expert.1', server_endpoint)
  17. model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))
  18. opt = torch.optim.SGD(model.parameters(), lr=0.05)
  19. for step in range(max_steps):
  20. opt.zero_grad()
  21. outputs = model(X_train)
  22. loss = F.cross_entropy(outputs, y_train)
  23. loss.backward()
  24. opt.step()
  25. accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
  26. if accuracy >= threshold:
  27. break
  28. assert accuracy >= threshold, f"too small accuracy: {accuracy}"