import time
from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits

from hivemind import DHT
from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
from hivemind.moe.server import background_server
from hivemind.optim import DecentralizedAdam, DecentralizedSGD
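
# Trains a small digit classifier whose hidden layers are two RemoteExperts hosted by a local
# background server (no DHT), and checks that training accuracy reaches the threshold.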
@pytest.mark.forked
def test_training(max_steps: int = 100, threshold: float = 0.9):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
    SGD = partial(torch.optim.SGD, lr=0.05)

    with background_server(num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1, no_dht=True) as (
        server_endpoint,
        _,
    ):
        expert1 = RemoteExpert("expert.0", server_endpoint)
        expert2 = RemoteExpert("expert.1", server_endpoint)
        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
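
# Same digit-classification setup, but the hidden layer is a RemoteMixtureOfExperts that
# discovers experts through a DHT and routes each input to its k_best experts.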
@pytest.mark.forked
def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
    subsample_ix = torch.randint(0, len(X_train), (32,))
    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
    SGD = partial(torch.optim.SGD, lr=0.05)

    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
    ) as (server_endpoint, dht_maddrs):
        dht = DHT(start=True, initial_peers=dht_maddrs)

        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
        model = nn.Sequential(moe, nn.Linear(64, 2))

        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs = model(X_train)
            loss = F.cross_entropy(outputs, y_train)
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
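
# Switch-style routing: each input is dispatched to a single expert (k_best=1), and the
# mixture also returns a load-balancing loss so traffic is spread across experts.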
class SwitchNetwork(nn.Module):
    def __init__(self, dht, in_features, num_classes, num_experts):
        super().__init__()
        self.moe = RemoteSwitchMixtureOfExperts(
            in_features=in_features,
            grid_size=(num_experts,),
            dht=dht,
            jitter_eps=0,
            uid_prefix="expert.",
            k_best=1,
            k_min=1,
        )
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        moe_output, balancing_loss = self.moe(x)
        return self.linear(moe_output), balancing_loss
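
# Trains the SwitchNetwork with the balancing loss added to the objective and additionally
# checks that every expert receives a non-negligible share of the routed inputs (grid_utilization).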
@pytest.mark.forked
def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
    dataset = load_digits(n_class=2)
    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
    subsample_ix = torch.randint(0, len(X_train), (32,))
    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]

    SGD = partial(torch.optim.SGD, lr=0.05)
    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
    ) as (server_endpoint, dht_maddrs):
        dht = DHT(start=True, initial_peers=dht_maddrs)

        model = SwitchNetwork(dht, 64, 2, num_experts)
        opt = SGD(model.parameters(), lr=0.05)

        for step in range(max_steps):
            outputs, balancing_loss = model(X_train)
            loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
            loss.backward()
            opt.step()
            opt.zero_grad()

            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
            if accuracy >= threshold:
                break

        assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
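
# Two DecentralizedSGD peers start from different parameters and learning rates; after one
# effective local step followed by averaging rounds, both should hold the same averaged value.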
@pytest.mark.forked
def test_decentralized_optimizer_step():
    dht_root = DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()

    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
    opt1 = DecentralizedSGD(
        [param1],
        lr=0.1,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
    opt2 = DecentralizedSGD(
        [param2],
        lr=0.05,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    assert not torch.allclose(param1, param2)

    (param1.sum() + 300 * param2.sum()).backward()

    for i in range(5):
        time.sleep(0.1)
        opt1.step()
        opt2.step()
        opt1.zero_grad()
        opt2.zero_grad()

    assert torch.allclose(param1, param2)
    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
    assert torch.allclose(param1, torch.full_like(param1, reference))
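
# Same two-peer setup with DecentralizedAdam averaging every step (averaging_steps_period=1):
# parameters and the exp_avg_sq optimizer statistics should match across peers afterwards.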
@pytest.mark.skip(reason="Skipped until a more stable averager implementation is ready (TODO @justheuristic)")
@pytest.mark.forked
def test_decentralized_optimizer_averaging():
    dht_root = DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()

    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
    opt1 = DecentralizedAdam(
        [param1],
        lr=0.1,
        averaging_steps_period=1,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
    opt2 = DecentralizedAdam(
        [param2],
        lr=0.05,
        averaging_steps_period=1,
        dht=DHT(initial_peers=initial_peers, start=True),
        prefix="foo",
        target_group_size=2,
        verbose=True,
    )

    assert not torch.allclose(param1, param2, atol=1e-3, rtol=0)

    (param1.sum() + param2.sum()).backward()

    for _ in range(100):
        time.sleep(0.1)
        opt1.step()
        opt2.step()
        opt1.zero_grad()
        opt2.zero_grad()

    assert torch.allclose(param1, param2, atol=1e-3, rtol=0)
    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"], atol=1e-3, rtol=0)