test_training.py

import time
from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits

from hivemind import DHT
from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
from hivemind.moe.server import background_server
from hivemind.optim import DecentralizedAdam, DecentralizedSGD


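# Trains a small digit classifier whose hidden layers are remote experts served by a background hivemind server.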
@pytest.mark.forked
def test_training(max_steps: int = 100, threshold: float = 0.9):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
    SGD = partial(torch.optim.SGD, lr=0.05)

    with background_server(
        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
    ) as server_peer_info:
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        expert1, expert2 = create_remote_experts(
            [
                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
            ],
            dht=dht,
        )
        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"


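# Same task, but inputs are routed to the experts through a RemoteMixtureOfExperts gating layer (top-2 routing).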
@pytest.mark.forked
def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
    subsample_ix = torch.randint(0, len(X_train), (32,))
    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]

    SGD = partial(torch.optim.SGD, lr=0.05)

    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
    ) as server_peer_info:
        dht = DHT(start=True, initial_peers=server_peer_info.addrs)

        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
        model = nn.Sequential(moe, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"


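# A thin wrapper that routes inputs through a RemoteSwitchMixtureOfExperts layer and returns both the
# classifier logits and the load-balancing loss produced by the switch gating.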
class SwitchNetwork(nn.Module):
    def __init__(self, dht, in_features, num_classes, num_experts):
        super().__init__()
        self.moe = RemoteSwitchMixtureOfExperts(
            in_features=in_features,
            grid_size=(num_experts,),
            dht=dht,
            jitter_eps=0,
            uid_prefix="expert.",
            k_best=1,
            k_min=1,
        )
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        moe_output, balancing_loss = self.moe(x)
        return self.linear(moe_output), balancing_loss


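# Trains the SwitchNetwork above with top-1 (switch-style) routing plus an auxiliary balancing loss term.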
@pytest.mark.forked
def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
    subsample_ix = torch.randint(0, len(X_train), (32,))
    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]

    SGD = partial(torch.optim.SGD, lr=0.05)

    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
    ) as server_peer_info:
        dht = DHT(start=True, initial_peers=server_peer_info.addrs)

        model = SwitchNetwork(dht, 64, 2, num_experts)
        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs, balancing_loss = model(X_train)
            loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

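        # every expert should receive at least half of its uniform (1 / num_experts) share of routed examples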
        assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
        assert accuracy >= threshold, f"too small accuracy: {accuracy}"


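# Two DecentralizedSGD peers start from different parameters and learning rates; after stepping and
# averaging they should end up with identical parameters.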
@pytest.mark.forked
def test_decentralized_optimizer_step():
    dht_root = DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()

    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
    opt1 = DecentralizedSGD(
        [param1],
        lr=0.1,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
    opt2 = DecentralizedSGD(
        [param2],
        lr=0.05,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    assert not torch.allclose(param1, param2)

    (param1.sum() + 300 * param2.sum()).backward()

    for i in range(5):
        time.sleep(0.1)
        opt1.step()
        opt2.step()
        opt1.zero_grad()
        opt2.zero_grad()

    assert torch.allclose(param1, param2)
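    # reference value: the average of both peers' parameters after one SGD step each
    # (peer 1: 0.0 - 0.1 * grad 1.0, peer 2: 1.0 - 0.05 * grad 300)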
    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
    assert torch.allclose(param1, torch.full_like(param1, reference))


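# Same idea for DecentralizedAdam: both peers should also agree on the averaged optimizer statistics.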
@pytest.mark.skip(reason="Skipped until a more stable averager implementation is ready (TODO @justheuristic)")
@pytest.mark.forked
def test_decentralized_optimizer_averaging():
    dht_root = DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()

    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
    opt1 = DecentralizedAdam(
        [param1],
        lr=0.1,
        averaging_steps_period=1,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
    opt2 = DecentralizedAdam(
        [param2],
        lr=0.05,
        averaging_steps_period=1,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    assert not torch.allclose(param1, param2, atol=1e-3, rtol=0)

    (param1.sum() + param2.sum()).backward()

    for _ in range(100):
        time.sleep(0.1)
        opt1.step()
        opt2.step()
        opt1.zero_grad()
        opt2.zero_grad()

    assert torch.allclose(param1, param2, atol=1e-3, rtol=0)
    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"], atol=1e-3, rtol=0)