test_training.py

import time
from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits

from hivemind import DHT
from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
from hivemind.moe.client.expert import create_remote_experts
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.moe.server import background_server
from hivemind.optim import DecentralizedAdam, DecentralizedSGD
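
# Fit a small classifier on a 2-class digits dataset where two layers are remote
# experts served by a local background_server; training must reach the accuracy
# threshold within max_steps.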
@pytest.mark.forked
def test_training(max_steps: int = 100, threshold: float = 0.9):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
    SGD = partial(torch.optim.SGD, lr=0.05)

    with background_server(
        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
    ) as server_peer_info:
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        expert1, expert2 = create_remote_experts(
            [
                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
            ],
            dht=dht,
        )
        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
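
# Same task, but inputs are routed through a RemoteMixtureOfExperts layer
# (k_best=2) whose experts are discovered through the DHT.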
@pytest.mark.forked
def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
    subsample_ix = torch.randint(0, len(X_train), (32,))
    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
    SGD = partial(torch.optim.SGD, lr=0.05)

    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
    ) as server_peer_info:
        dht = DHT(start=True, initial_peers=server_peer_info.addrs)

        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
        model = nn.Sequential(moe, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
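
# A small model that pairs a RemoteSwitchMixtureOfExperts layer (one expert per
# input, k_best=1) with a linear classifier head; forward() returns the logits
# together with the expert load-balancing loss.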
class SwitchNetwork(nn.Module):
    def __init__(self, dht, in_features, num_classes, num_experts):
        super().__init__()
        self.moe = RemoteSwitchMixtureOfExperts(
            in_features=in_features,
            grid_size=(num_experts,),
            dht=dht,
            jitter_eps=0,
            uid_prefix="expert.",
            k_best=1,
            k_min=1,
        )
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        moe_output, balancing_loss = self.moe(x)
        return self.linear(moe_output), balancing_loss
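
# Train SwitchNetwork with cross-entropy plus a small (0.01) balancing penalty;
# besides accuracy, check that every expert keeps at least half of its uniform
# share of utilization.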
@pytest.mark.forked
def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
    subsample_ix = torch.randint(0, len(X_train), (32,))
    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]

    SGD = partial(torch.optim.SGD, lr=0.05)
    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
    ) as server_peer_info:
        dht = DHT(start=True, initial_peers=server_peer_info.addrs)

        model = SwitchNetwork(dht, 64, 2, num_experts)
        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs, balancing_loss = model(X_train)
            loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
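
# Two DecentralizedSGD peers with different initial parameters and learning rates
# take local steps and average their parameters; both must converge to the same
# value, which is compared against a hand-computed reference.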
@pytest.mark.forked
def test_decentralized_optimizer_step():
    dht_root = DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()

    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
    opt1 = DecentralizedSGD(
        [param1],
        lr=0.1,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
    opt2 = DecentralizedSGD(
        [param2],
        lr=0.05,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    assert not torch.allclose(param1, param2)

    (param1.sum() + 300 * param2.sum()).backward()

    for i in range(5):
        time.sleep(0.1)
        opt1.step()
        opt2.step()
        opt1.zero_grad()
        opt2.zero_grad()

    assert torch.allclose(param1, param2)
    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
    assert torch.allclose(param1, torch.full_like(param1, reference))
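
# Currently skipped: the same setup with DecentralizedAdam averaging after every
# step; parameters and the second-moment statistics (exp_avg_sq) must match
# across peers once training finishes.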
@pytest.mark.skip(reason="Skipped until a more stable averager implementation is ready (TODO @justheuristic)")
@pytest.mark.forked
def test_decentralized_optimizer_averaging():
    dht_root = DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()

    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
    opt1 = DecentralizedAdam(
        [param1],
        lr=0.1,
        averaging_steps_period=1,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
    opt2 = DecentralizedAdam(
        [param2],
        lr=0.05,
        averaging_steps_period=1,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    assert not torch.allclose(param1, param2, atol=1e-3, rtol=0)

    (param1.sum() + param2.sum()).backward()

    for _ in range(100):
        time.sleep(0.1)
        opt1.step()
        opt2.step()
        opt1.zero_grad()
        opt2.zero_grad()

    assert torch.allclose(param1, param2, atol=1e-3, rtol=0)
    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"], atol=1e-3, rtol=0)