test_training.py

import time
from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits

from hivemind import DHT
from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
from hivemind.moe.server import background_server
from hivemind.optim import DecentralizedSGD, DecentralizedAdam


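# Sanity check: a model whose hidden layers are two server-hosted RemoteExperts
# should fit a two-class digits task within max_steps.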
@pytest.mark.forked
def test_training(max_steps: int = 100, threshold: float = 0.9):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)

    with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1,
                           no_dht=True) as (server_endpoint, _):
        expert1 = RemoteExpert('expert.0', server_endpoint)
        expert2 = RemoteExpert('expert.1', server_endpoint)
        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

    assert accuracy >= threshold, f"too small accuracy: {accuracy}"


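# Same task, but routed through a RemoteMixtureOfExperts layer that discovers
# the experts via a DHT and mixes the outputs of the k_best of them.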
@pytest.mark.forked
def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)
    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]

    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64,
                           num_handlers=1) as (server_endpoint, dht_maddrs):
        dht = DHT(start=True, initial_peers=dht_maddrs)
        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix='expert.', k_best=2)
        model = nn.Sequential(moe, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

    assert accuracy >= threshold, f"too small accuracy: {accuracy}"


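# Thin wrapper around RemoteSwitchMixtureOfExperts: with k_best=1, each input is
# routed to a single expert, and forward() also returns the auxiliary
# load-balancing loss produced by the router.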
class SwitchNetwork(nn.Module):
    def __init__(self, dht, in_features, num_classes, num_experts):
        super().__init__()
        self.moe = RemoteSwitchMixtureOfExperts(in_features=in_features, grid_size=(num_experts,), dht=dht,
                                                jitter_eps=0, uid_prefix='expert.', k_best=1, k_min=1)
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        moe_output, balancing_loss = self.moe(x)
        return self.linear(moe_output), balancing_loss


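# Switch training should fit the task while keeping expert utilization balanced:
# the balancing loss is added to the objective with weight 0.01, and the final
# assert requires that no expert receives less than half of the uniform share.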
@pytest.mark.forked
def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
    SGD = partial(torch.optim.SGD, lr=0.05)
    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]

    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64,
                           num_handlers=1) as (server_endpoint, dht_maddrs):
        dht = DHT(start=True, initial_peers=dht_maddrs)
        model = SwitchNetwork(dht, 64, 2, num_experts)
        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs, balancing_loss = model(X_train)
            loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

    assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
    assert accuracy >= threshold, f"too small accuracy: {accuracy}"


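# Two DecentralizedSGD peers take one local SGD step each, then average their
# parameters: peer 1 goes from 0 to 0 - 0.1 * 1 = -0.1, peer 2 from 1 to
# 1 - 0.05 * 300 = -14, so every entry should converge to their mean, -7.05.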
@pytest.mark.forked
def test_decentralized_optimizer_step():
    dht_root = DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()

    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
    opt1 = DecentralizedSGD([param1], lr=0.1, dht=DHT(initial_peers=initial_peers, start=True),
                            prefix='foo', target_group_size=2, verbose=True)

    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
    opt2 = DecentralizedSGD([param2], lr=0.05, dht=DHT(initial_peers=initial_peers, start=True),
                            prefix='foo', target_group_size=2, verbose=True)

    assert not torch.allclose(param1, param2)

    (param1.sum() + 300 * param2.sum()).backward()

    opt1.step()
    opt2.step()

    time.sleep(0.5)

    assert torch.allclose(param1, param2)
    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
    assert torch.allclose(param1, torch.full_like(param1, reference))


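# DecentralizedAdam with averaging_steps_period=1 should synchronize not just the
# parameters but also the optimizer statistics (here, exp_avg_sq) after each step.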
@pytest.mark.forked
def test_decentralized_optimizer_averaging():
    dht_root = DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()

    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
    opt1 = DecentralizedAdam([param1], lr=0.1, averaging_steps_period=1,
                             dht=DHT(initial_peers=initial_peers, start=True),
                             prefix='foo', target_group_size=2, verbose=True)

    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
    opt2 = DecentralizedAdam([param2], lr=0.05, averaging_steps_period=1,
                             dht=DHT(initial_peers=initial_peers, start=True),
                             prefix='foo', target_group_size=2, verbose=True)

    assert not torch.allclose(param1, param2)

    (param1.sum() + param2.sum()).backward()

    opt1.step()
    opt2.step()

    time.sleep(0.5)

    assert torch.allclose(param1, param2)
    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"])