test_training.py

from functools import partial
import time

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits

from hivemind import RemoteExpert, background_server, DHT, DecentralizedSGD


@pytest.mark.forked
def test_training(max_steps: int = 100, threshold: float = 0.9):
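    # Train a small classifier whose hidden layers are two experts hosted on a
    # background hivemind server; pass once training accuracy reaches `threshold`.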
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)

    with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1,
                           no_dht=True) as (server_endpoint, dht_endpoint):
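        # RemoteExpert behaves like a local nn.Module, but its forward and backward
        # passes are executed by the server at `server_endpoint`.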
        expert1 = RemoteExpert('expert.0', server_endpoint)
        expert2 = RemoteExpert('expert.1', server_endpoint)
        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))

        opt = torch.optim.SGD(model.parameters(), lr=0.05)
        for step in range(max_steps):
            opt.zero_grad()
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
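
            # stop early once the model fits the training batch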
            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"


@pytest.mark.forked
def test_decentralized_optimizer_step():
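    # Two peers start from different parameters and learning rates; after one
    # local SGD step plus decentralized averaging, their parameters should match.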
    dht_root = DHT(start=True)
    initial_peers = [f"127.0.0.1:{dht_root.port}"]
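
    # each optimizer runs its own DHT peer, bootstrapped from dht_root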
    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
    opt1 = DecentralizedSGD([param1], lr=0.1, dht=DHT(initial_peers=initial_peers, start=True),
                            prefix='foo', target_group_size=2, verbose=True)

    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
    opt2 = DecentralizedSGD([param2], lr=0.05, dht=DHT(initial_peers=initial_peers, start=True),
                            prefix='foo', target_group_size=2, verbose=True)

    assert not torch.allclose(param1, param2)
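
    # every element of param1 receives gradient 1.0, every element of param2 gradient 300.0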
    (param1.sum() + 300 * param2.sum()).backward()

    opt1.step()
    opt2.step()
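    # give the decentralized averaging a moment to propagate before comparing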
    time.sleep(0.5)

    assert torch.allclose(param1, param2)
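    # expected value per element: the average of the two locally updated parameters,
    # 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300) = -7.05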
    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
    assert torch.allclose(param1, torch.full_like(param1, reference))