# test_training.py — integration tests for hivemind mixture-of-experts training.
  1. from functools import partial
  2. import pytest
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from sklearn.datasets import load_digits
  7. from hivemind import DHT
  8. from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
  9. from hivemind.moe.client.expert import create_remote_experts
  10. from hivemind.moe.expert_uid import ExpertInfo
  11. from hivemind.moe.server import background_server
  12. @pytest.mark.forked
  13. def test_training(max_steps: int = 100, threshold: float = 0.9):
  14. dataset = load_digits(n_class=2)
  15. X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
  16. SGD = partial(torch.optim.SGD, lr=0.05)
  17. with background_server(
  18. num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
  19. ) as server_peer_info:
  20. dht = DHT(initial_peers=server_peer_info.addrs, start=True)
  21. expert1, expert2 = create_remote_experts(
  22. [
  23. ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
  24. ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
  25. ],
  26. dht=dht,
  27. )
  28. model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
  29. opt = SGD(model.parameters(), lr=0.05)
  30. for step in range(max_steps):
  31. outputs = model(X_train)
  32. loss = F.cross_entropy(outputs, y_train)
  33. loss.backward()
  34. opt.step()
  35. opt.zero_grad()
  36. accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
  37. if accuracy >= threshold:
  38. break
  39. assert accuracy >= threshold, f"too small accuracy: {accuracy}"
  40. @pytest.mark.forked
  41. def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
  42. dataset = load_digits(n_class=2)
  43. X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
  44. subsample_ix = torch.randint(0, len(X_train), (32,))
  45. X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
  46. SGD = partial(torch.optim.SGD, lr=0.05)
  47. all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
  48. with background_server(
  49. expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
  50. ) as server_peer_info:
  51. dht = DHT(start=True, initial_peers=server_peer_info.addrs)
  52. moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
  53. model = nn.Sequential(moe, nn.Linear(64, 2))
  54. opt = SGD(model.parameters(), lr=0.05)
  55. for step in range(max_steps):
  56. outputs = model(X_train)
  57. loss = F.cross_entropy(outputs, y_train)
  58. loss.backward()
  59. opt.step()
  60. opt.zero_grad()
  61. accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
  62. if accuracy >= threshold:
  63. break
  64. assert accuracy >= threshold, f"too small accuracy: {accuracy}"
  65. class SwitchNetwork(nn.Module):
  66. def __init__(self, dht, in_features, num_classes, num_experts):
  67. super().__init__()
  68. self.moe = RemoteSwitchMixtureOfExperts(
  69. in_features=in_features,
  70. grid_size=(num_experts,),
  71. dht=dht,
  72. jitter_eps=0,
  73. uid_prefix="expert.",
  74. k_best=1,
  75. k_min=1,
  76. )
  77. self.linear = nn.Linear(in_features, num_classes)
  78. def forward(self, x):
  79. moe_output, balancing_loss = self.moe(x)
  80. return self.linear(moe_output), balancing_loss
  81. @pytest.mark.forked
  82. def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
  83. dataset = load_digits(n_class=2)
  84. X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
  85. subsample_ix = torch.randint(0, len(X_train), (32,))
  86. X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
  87. SGD = partial(torch.optim.SGD, lr=0.05)
  88. all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
  89. with background_server(
  90. expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
  91. ) as server_peer_info:
  92. dht = DHT(start=True, initial_peers=server_peer_info.addrs)
  93. model = SwitchNetwork(dht, 64, 2, num_experts)
  94. opt = SGD(model.parameters(), lr=0.05)
  95. for step in range(max_steps):
  96. outputs, balancing_loss = model(X_train)
  97. loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
  98. loss.backward()
  99. opt.step()
  100. opt.zero_grad()
  101. accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
  102. if accuracy >= threshold:
  103. break
  104. assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
  105. assert accuracy >= threshold, f"too small accuracy: {accuracy}"