example.py

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm.auto import tqdm

import hivemind


class SmallCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, (9, 9)),
            nn.ReLU(),
            nn.Conv2d(16, 16, (9, 9)),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
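        # With 28x28 MNIST inputs, the layers above yield 16 feature maps of size 6x6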
        self.cls = nn.Sequential(
            nn.Linear(16 * 6 * 6, 400),
            nn.ReLU(),
            nn.Linear(400, 10)
        )

    def forward(self, x):
        feature = self.features(x)
        return self.cls(feature.view(x.size(0), -1))


if __name__ == "__main__":
    # Create dataset and model, same as in the basic tutorial
    # For this basic tutorial, we download only the training set
    transform = transforms.Compose([transforms.ToTensor()])
    trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    model = SmallCNN()
    opt = torch.optim.Adam(model.parameters(), lr=0.001)

    # Create DHT: a decentralized key-value storage shared between peers
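    # The multiaddr below is an example; replace it with the address printed by a peer that is already running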
    dht = hivemind.DHT(start=True, initial_peers=["/ip4/127.0.0.1/tcp/36805/p2p/Qmc7nJt6Pc3Eii4X1ZqtkxbiRWvf97nNfuD4CJpAep5THU"])
    print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])

    # Set up a decentralized optimizer that will average with peers in the background
    opt = hivemind.Optimizer(
        dht=dht,  # use a DHT that is connected with other peers
        run_id='my_cifar_run',  # unique identifier of this collaborative run
        batch_size_per_step=16,  # each call to opt.step adds this many samples towards the next epoch
        target_batch_size=1000,  # after peers collectively process this many samples, average weights and begin the next epoch
        optimizer=opt,  # wrap the Adam optimizer defined above
        use_local_updates=False,  # perform optimizer steps with averaged gradients
        matchmaking_time=3.0,  # when averaging parameters, gather peers in the background for up to this many seconds
        averaging_timeout=10.0,  # give up on averaging if not successful in this many seconds
        verbose=True,  # print logs incessantly
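        # The two options below request rank-1 PowerSGD gradient compression to reduce communication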
        grad_rank_averager="power_sgd",
        grad_averager_opts={"averager_rank": 1}
    )
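    # Download the latest model and optimizer state from other peers so this peer starts in sync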
    opt.load_state_from_peers()

    # Note: if you intend to use a GPU, switch to it only after the decentralized optimizer is created
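    # Train indefinitely: epochs are counted collectively via target_batch_size, not by local passes over the data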
    with tqdm() as progressbar:
        while True:
            for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=16):
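                # Optional small delay between steps, e.g. to emulate a slower peer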
                time.sleep(0.1)
                opt.zero_grad()
                loss = F.cross_entropy(model(x_batch), y_batch)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()

                progressbar.desc = f"loss = {loss.item():.3f}"
                progressbar.update()