# playground_example.py — hivemind collaborative-training playground example (~2.9 KB)
  1. import hivemind
  2. from hivemind.optim.experimental.grad_averager import GradientAverager
  3. from hivemind.optim.experimental.power_ef_averager import PowerEFGradientAverager
  4. from hivemind.optim.experimental.power_sgd_averager import PowerSGDGradientAverager
  5. import faulthandler
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torchvision
  10. from torchvision.datasets import MNIST
  11. import multiprocessing as mp
  12. import threading
  13. import os
  14. import random
  15. import time
print_step = 10  # logging interval; NOTE(review): not referenced anywhere in this snippet — presumably leftover
  17. class Peer(threading.Thread):
  18. def __init__(self, idx, *, start: bool):
  19. super().__init__(daemon=True)
  20. self.dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
  21. self.model = SmallCNN()
  22. for param in self.model.parameters():
  23. param.grad = torch.zeros_like(param).share_memory_()
  24. if start:
  25. self.start()
  26. self.idx = idx
  27. def run(self):
  28. torch.manual_seed(self.idx)
  29. print('started', self.dht.peer_id)
  30. transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
  31. train_data = MNIST(f".", download=True, transform=transform)
  32. def data():
  33. while True:
  34. train_dataloader = torch.utils.data.DataLoader(train_data, num_workers=0, batch_size=64, shuffle=True)
  35. for batch in train_dataloader:
  36. yield batch
  37. opt = hivemind.Optimizer(
  38. dht=self.dht,
  39. prefix="my_super_run",
  40. params=self.model.parameters(),
  41. optimizer=torch.optim.SGD,
  42. lr=0.1,
  43. train_batch_size=256,
  44. batch_size=64
  45. )
  46. opt.load_state_from_peers()
  47. for i, (xb, yb) in enumerate(data()):
  48. logits = self.model(xb)
  49. loss = F.cross_entropy(logits, yb)
  50. loss.backward()
  51. torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
  52. self.averager.accumulate_grads_(batch_size=64)
  53. opt.step()
  54. opt.zero_grad()
  55. if i > 100000: break
  56. class SmallCNN(nn.Module):
  57. def __init__(self):
  58. super().__init__()
  59. self.features = nn.Sequential(
  60. nn.Conv2d(1, 16, (9, 9)),
  61. nn.ReLU(),
  62. nn.Conv2d(16, 16, (9, 9)),
  63. nn.ReLU(),
  64. nn.MaxPool2d(2)
  65. )
  66. self.cls = nn.Sequential(
  67. nn.Linear(16 * 6 * 6, 400),
  68. nn.ReLU(),
  69. nn.Linear(400, 10)
  70. )
  71. def forward(self, x):
  72. feature = self.features(x)
  73. return self.cls(feature.view(x.size(0), -1))
  74. if __name__ == "__main__":
  75. dht_root = hivemind.DHT(start=True)
  76. peers = [
  77. Peer(i, start=False) for i in range(4)
  78. ]
  79. for i in range(1, 4):
  80. peers[i].model.load_state_dict(peers[0].model.state_dict())
  81. for peer in peers:
  82. peer.start()
  83. for p in peers:
  84. p.join()