your_code_here.py:

import torch
import torch.nn as nn

from hivemind.moe.server.layers.custom_experts import register_expert_class


# Register the module under the name "ExampleModule". The lambda builds a dummy
# input of shape (batch_size, hid_dim) that is used to describe the expert's
# expected input when a server loads this file.
@register_expert_class("ExampleModule", lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)))
class ExampleModule(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.ffn = nn.Linear(hid_dim, 4 * hid_dim)
        self.ffn_output = nn.Linear(4 * hid_dim, hid_dim)
        self.layer_norm = nn.LayerNorm(hid_dim, eps=1e-12)

    def forward(self, x):
        # Transformer-style feed-forward block with a residual connection.
        ffn_output = self.ffn(x)
        ffn_output = torch.nn.functional.gelu(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        return self.layer_norm(x + ffn_output)
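
Before serving the expert, it can help to run a quick standalone forward pass to confirm the shapes line up. The snippet below is a minimal sketch that assumes the registration decorator returns the class unchanged, so it behaves like any regular nn.Module; the hidden size of 512 and batch size of 4 are arbitrary example values.

# Quick sanity check of the expert in isolation (no hivemind server involved).
import torch

expert = ExampleModule(hid_dim=512)
dummy_batch = torch.randn(4, 512)       # (batch_size, hid_dim)
out = expert(dummy_batch)
assert out.shape == dummy_batch.shape   # the residual block preserves the input shape
print(out.shape)                        # torch.Size([4, 512])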