playgroud_example.py

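"""Playground example: four peers train a small CNN on MNIST in separate threads and
periodically average their gradients with hivemind's experimental PowerEFGradientAverager.
"""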
import threading

import hivemind
from hivemind.optim.experimental.power_ef_averager import PowerEFGradientAverager

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
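
# Each Peer runs in its own daemon thread: it joins the DHT, trains a local SmallCNN
# on MNIST, and roughly every 5 seconds takes part in a scheduled gradient-averaging
# round before applying an optimizer step on the averaged gradients.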
class Peer(threading.Thread):
    def __init__(self, idx, *, start: bool):
        super().__init__(daemon=True)
        self.idx = idx
        # Join the swarm via the bootstrap DHT node (dht_root) created in __main__.
        self.dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
        self.model = SmallCNN()
        # Pre-allocate gradient buffers in shared memory so the averager can read them.
        for param in self.model.parameters():
            param.grad = torch.zeros_like(param).share_memory_()
        # Gradient averager with low-rank (rank=1) error-feedback compression.
        self.averager = PowerEFGradientAverager(
            self.model.parameters(), 1, dht=self.dht, target_group_size=4, prefix='my_mega_exp', start=True,
        )
        if start:
            self.start()
    def run(self):
        torch.manual_seed(self.idx)
        print('started', self.dht.peer_id)

        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
        train_data = MNIST(".", download=True, transform=transform)

        def data():
            # Yield MNIST batches indefinitely, reshuffling the dataset on every pass.
            while True:
                train_dataloader = torch.utils.data.DataLoader(
                    train_data, num_workers=0, batch_size=1024, shuffle=True
                )
                for batch in train_dataloader:
                    yield batch

        opt = torch.optim.Adam(self.model.parameters(), lr=0.001)
        # Schedule averaging rounds ~5 seconds apart; the control object is created
        # ~1 second ahead of the deadline so the peers have time to form a group.
        next_step_time = hivemind.get_dht_time() + 5
        next_step_control = None

        for i, (xb, yb) in enumerate(data()):
            logits = self.model(xb)
            loss = F.cross_entropy(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            if next_step_control is None and (next_step_time - hivemind.get_dht_time() <= 1):
                next_step_control = self.averager.schedule_step(scheduled_time=next_step_time)

            # Add local gradients to the averager's accumulators, then zero the local
            # .grad buffers in place so the next backward pass does not double-count them.
            self.averager.accumulate_grads_(batch_size=1024)
            self.model.zero_grad(set_to_none=False)

            if hivemind.get_dht_time() >= next_step_time:
                # Run the pre-scheduled averaging round and wait for it to complete.
                self.averager.step(control=next_step_control)
                next_step_control.result()
                # Temporarily swap the averaged gradients into the model, log a few
                # values for debugging, and apply the optimizer step.
                with self.averager.use_averaged_gradients():
                    with torch.no_grad():
                        param = next(iter(self.model.parameters()))
                        grad = param.grad.detach().cpu().norm().item()
                        print_param = param.flatten()[-3:].detach().cpu().numpy()
                        print(i, self.dht.peer_id.pretty()[-3:], f"{loss.item():.3f}",
                              f"{hivemind.get_dht_time():.3f}", print_param, grad)
                    opt.step()
                self.averager.reset_accumulated_grads_()
                next_step_time = hivemind.get_dht_time() + 5
                next_step_control = None

            if i > 10000:
                break
class SmallCNN(nn.Module):
    # Four 5x5 convolutions plus a 2x2 max-pool turn a 1x28x28 MNIST image into a
    # 64x6x6 feature map, followed by a two-layer classifier head.
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 4, (5, 5)),
            nn.ReLU(),
            nn.Conv2d(4, 16, (5, 5)),
            nn.ReLU(),
            nn.Conv2d(16, 64, (5, 5)),
            nn.ReLU(),
            nn.Conv2d(64, 64, (5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.cls = nn.Sequential(
            nn.Linear(64 * 6 * 6, 400),
            nn.ReLU(),
            nn.Linear(400, 10),
        )

    def forward(self, x):
        feature = self.features(x)
        return self.cls(feature.view(x.size(0), -1))
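
# Entry point: create a bootstrap DHT node, then launch four peer threads that all
# start from peer 0's weights so that averaging their gradients is meaningful.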
if __name__ == "__main__":
    dht_root = hivemind.DHT(start=True)
    peers = [
        Peer(0, start=False), Peer(1, start=False),
        Peer(2, start=False), Peer(3, start=False),
    ]
    peers[1].model.load_state_dict(peers[0].model.state_dict())
    peers[2].model.load_state_dict(peers[0].model.state_dict())
    peers[3].model.load_state_dict(peers[0].model.state_dict())
    for peer in peers:
        peer.start()
    for peer in peers:
        peer.join()