custom_networks.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from hivemind.server.layers.custom_experts import register_expert_class

# The sample-input factory tells hivemind what input shapes the expert expects.
sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))


@register_expert_class('perceptron', sample_input)
class MultilayerPerceptron(nn.Module):
    def __init__(self, hidden_dim, num_classes=10):
        super(MultilayerPerceptron, self).__init__()
        self.layer1 = nn.Linear(hidden_dim, 2 * hidden_dim)
        self.layer2 = nn.Linear(2 * hidden_dim, 2 * hidden_dim)
        self.layer3 = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)
        return x

# An expert may take several inputs: the sample-input factory then returns a tuple of tensors.
multihead_sample_input = lambda batch_size, hidden_dim: (
    torch.empty((batch_size, hidden_dim)),
    torch.empty((batch_size, 2 * hidden_dim)),
    torch.empty((batch_size, 3 * hidden_dim)),
)


@register_expert_class('multihead', multihead_sample_input)
class MultiheadNetwork(nn.Module):
    def __init__(self, hidden_dim, num_classes=10):
        super(MultiheadNetwork, self).__init__()
        self.layer1 = nn.Linear(hidden_dim, num_classes)
        self.layer2 = nn.Linear(2 * hidden_dim, num_classes)
        self.layer3 = nn.Linear(3 * hidden_dim, num_classes)

    def forward(self, x1, x2, x3):
        x = self.layer1(x1) + self.layer2(x2) + self.layer3(x3)
        return x
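

# Illustrative local smoke test (a sketch, not prescribed by hivemind): hidden_dim=512 and
# batch_size=4 are assumed values. It builds inputs with the same sample-input factories
# registered above and checks that both experts return a (batch_size, num_classes) tensor.
if __name__ == '__main__':
    hidden_dim, batch_size, num_classes = 512, 4, 10

    perceptron = MultilayerPerceptron(hidden_dim, num_classes)
    single_out = perceptron(sample_input(batch_size, hidden_dim))
    assert single_out.shape == (batch_size, num_classes)

    multihead = MultiheadNetwork(hidden_dim, num_classes)
    multi_out = multihead(*multihead_sample_input(batch_size, hidden_dim))
    assert multi_out.shape == (batch_size, num_classes)

    print('perceptron output:', tuple(single_out.shape), '| multihead output:', tuple(multi_out.shape))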