# custom_networks.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from hivemind.moe import register_expert_class


  5. sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))
  6. @register_expert_class("perceptron", sample_input)
  7. class MultilayerPerceptron(nn.Module):
  8. def __init__(self, hidden_dim, num_classes=10):
  9. super().__init__()
  10. self.layer1 = nn.Linear(hidden_dim, 2 * hidden_dim)
  11. self.layer2 = nn.Linear(2 * hidden_dim, 2 * hidden_dim)
  12. self.layer3 = nn.Linear(2 * hidden_dim, num_classes)
  13. def forward(self, x):
  14. x = F.relu(self.layer1(x))
  15. x = F.relu(self.layer2(x))
  16. x = self.layer3(x)
  17. return x
  18. multihead_sample_input = lambda batch_size, hidden_dim: (
  19. torch.empty((batch_size, hidden_dim)),
  20. torch.empty((batch_size, 2 * hidden_dim)),
  21. torch.empty((batch_size, 3 * hidden_dim)),
  22. )
  23. @register_expert_class("multihead", multihead_sample_input)
  24. class MultiheadNetwork(nn.Module):
  25. def __init__(self, hidden_dim, num_classes=10):
  26. super().__init__()
  27. self.layer1 = nn.Linear(hidden_dim, num_classes)
  28. self.layer2 = nn.Linear(2 * hidden_dim, num_classes)
  29. self.layer3 = nn.Linear(3 * hidden_dim, num_classes)
  30. def forward(self, x1, x2, x3):
  31. x = self.layer1(x1) + self.layer2(x2) + self.layer3(x3)
  32. return x