common.py

import torch
from torch import nn


class FeedforwardBlock(nn.Module):
    """Position-wise feedforward network with a residual connection around it."""

    def __init__(self, hid_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(hid_dim, 4 * hid_dim),
            nn.LayerNorm(4 * hid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(4 * hid_dim, 4 * hid_dim),
            nn.LayerNorm(4 * hid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(4 * hid_dim, hid_dim),
        )

    def forward(self, x):
        return x + self.layers(x)
class TransformerEncoderLayer(nn.Module):
    """
    A slight modification of torch.nn.TransformerEncoderLayer which allows for torch.jit scripting
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feedforward sub-layer
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, src):
        # (batch, seq, dim) -> (seq, batch, dim); out-of-place so the caller's tensor is not mutated
        src = src.transpose(0, 1)
        src2 = self.self_attn(src, src, src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        # (seq, batch, dim) -> (batch, seq, dim)
        return src.transpose(0, 1)
class NopExpert(nn.Sequential):
    """No-op expert: returns a copy of its input; the zero-size parameter keeps the module non-empty for optimizers."""

    def __init__(self, hid_dim):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(0), requires_grad=True)

    def forward(self, x):
        return x.clone()
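

# Minimal usage sketch: a quick shape and scripting check for the layers above.
# The dimensions below are arbitrary values chosen for illustration only;
# torch.jit.script is applied to TransformerEncoderLayer since scriptability is
# the property its docstring advertises.
if __name__ == "__main__":
    batch, seq_len, hid_dim = 2, 16, 64

    ffn = FeedforwardBlock(hid_dim)
    x = torch.randn(batch, hid_dim)
    assert ffn(x).shape == x.shape  # residual block preserves the input shape

    encoder_layer = torch.jit.script(TransformerEncoderLayer(d_model=hid_dim, nhead=4))
    src = torch.randn(batch, seq_len, hid_dim)
    assert encoder_layer(src).shape == src.shape  # (batch, seq, dim) in and out

    nop = NopExpert(hid_dim)
    assert torch.equal(nop(x), x)  # identity on the input values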