layers.py

import torch
import torch.nn as nn

from hivemind.utils.custom_layers import DeterministicDropout


class FeedforwardBlock(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(hid_dim, 4 * hid_dim),
            nn.LayerNorm(4 * hid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(4 * hid_dim, 4 * hid_dim),
            nn.LayerNorm(4 * hid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(4 * hid_dim, hid_dim),
        )

    def forward(self, x):
        # residual connection around the feedforward stack
        return x + self.layers(x)


class TransformerEncoderLayer(nn.Module):
    """
    A slight modification of torch.nn.TransformerEncoderLayer which allows for torch.jit scripting
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, src):
        # switch to (seq_len, batch, d_model) as expected by nn.MultiheadAttention
        src.transpose_(0, 1)
        src2 = self.self_attn(src, src, src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        # switch back to (batch, seq_len, d_model)
        src.transpose_(0, 1)
        return src
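

# Usage sketch (not part of the original file): the docstring above claims TorchScript
# compatibility, which can be checked roughly as below. The helper name, d_model, nhead
# and the input shape are illustrative assumptions.
def _example_script_transformer_layer(d_model: int = 512, nhead: int = 8) -> torch.Tensor:
    layer = TransformerEncoderLayer(d_model, nhead)
    scripted = torch.jit.script(layer)
    dummy = torch.randn(4, 16, d_model)  # (batch, seq_len, d_model); transposed internally
    return scripted(dummy)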


class NopExpert(nn.Sequential):
    def __init__(self, hid_dim):
        super().__init__()
        # dummy zero-size trainable parameter
        self.w = nn.Parameter(torch.zeros(0), requires_grad=True)

    def forward(self, x):
        return x.clone()


class DeterministicDropoutNetwork(nn.Module):
    def __init__(self, hid_dim, dropout_prob):
        super().__init__()
        self.linear_in = nn.Linear(hid_dim, 2 * hid_dim)
        self.activation = nn.ReLU()
        self.dropout = DeterministicDropout(dropout_prob)
        self.linear_out = nn.Linear(2 * hid_dim, hid_dim)

    def forward(self, x, mask):
        # dropout is driven by an externally supplied mask rather than fresh randomness
        x = self.linear_in(self.dropout(x, mask))
        return self.linear_out(self.activation(x))


# name -> expert constructor
name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
                 'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, nhead=16),
                 'nop': lambda hid_dim: NopExpert(hid_dim),
                 'det_dropout': lambda hid_dim: DeterministicDropoutNetwork(hid_dim, dropout_prob=0.2)}

# name -> factory for a matching dummy input batch
name_to_input = {'ffn': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
                 'transformer': lambda batch_size, hid_dim: torch.empty((batch_size, 512, hid_dim)),
                 'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
                 'det_dropout': lambda batch_size, hid_dim:
                 (torch.empty((batch_size, hid_dim)), torch.randint(0, 1, (batch_size, hid_dim)))}
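

# Usage sketch (not part of the original file): the two registries share keys, so an
# expert and a matching dummy batch can be built by name; the sizes below are
# illustrative assumptions.
if __name__ == "__main__":
    batch_size, hid_dim = 4, 1024
    expert = name_to_block['ffn'](hid_dim)
    dummy_input = name_to_input['ffn'](batch_size, hid_dim)
    print(expert(dummy_input).shape)  # torch.Size([4, 1024])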