test_training.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #%env CUDA_VISIBLE_DEVICES=
  2. import argparse
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from hivemind import RemoteExpert, find_open_port
  8. from test_utils.run_server import background_server
  9. from sklearn.datasets import load_digits
  10. def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: float = 0.9):
  11. if port is None:
  12. port = find_open_port()
  13. dataset = load_digits()
  14. X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
  15. with background_server(num_experts=2, device='cpu', port=port, hidden_dim=64):
  16. expert1 = RemoteExpert('expert.0', host='127.0.0.1', port=port)
  17. expert2 = RemoteExpert('expert.1', host='127.0.0.1', port=port)
  18. model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))
  19. opt = torch.optim.SGD(model.parameters(), lr=0.05)
  20. for step in range(max_steps):
  21. opt.zero_grad()
  22. outputs = model(X_train)
  23. loss = F.cross_entropy(outputs, y_train)
  24. loss.backward()
  25. opt.step()
  26. accuracy = (outputs.argmax(dim=1) == y_train).numpy().mean()
  27. if accuracy >= threshold:
  28. break
  29. assert accuracy >= threshold, f"too small accuracy: {accuracy}"