@@ -2,21 +2,24 @@ import torch
 from torch import nn as nn
 
 
+# https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py
+@torch.jit.script
+def gelu_fast(x):
+    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
+
+
 class FeedforwardBlock(nn.Module):
     def __init__(self, hid_dim):
         super().__init__()
-        self.layers = nn.Sequential(
-            nn.Linear(hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, hid_dim),
-        )
+        self.ffn = nn.Linear(hid_dim, 4 * hid_dim)
+        self.ffn_output = nn.Linear(4 * hid_dim, hid_dim)
+        self.layer_norm = nn.LayerNorm(hid_dim, eps=1e-12)
 
     def forward(self, x):
-        return x + self.layers(x)
+        ffn_output = self.ffn(x)
+        ffn_output = gelu_fast(ffn_output)
+        ffn_output = self.ffn_output(ffn_output)
+        return self.layer_norm(x + ffn_output)
 
 
 class TransformerEncoderLayer(nn.Module):
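Note on the hunk above: gelu_fast is the tanh approximation of GELU taken from the linked transformers file; the constant 0.7978845608 is sqrt(2/pi) to ten digits. A minimal sanity check, assuming PyTorch >= 1.12 where F.gelu accepts approximate="tanh":

import torch
import torch.nn.functional as F

x = torch.randn(8, 16)
# The scripted approximation should match PyTorch's built-in tanh GELU
# up to float32 rounding.
assert torch.allclose(gelu_fast(x), F.gelu(x, approximate="tanh"), atol=1e-5)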
@@ -37,7 +40,7 @@ class TransformerEncoderLayer(nn.Module):
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
 
-        self.activation = torch.nn.GELU()
+        self.activation = gelu_fast
 
     def forward(self, src, src_key_padding_mask=None):
         # (N, S, E) -> (S, N, E)
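With these changes, FeedforwardBlock becomes a BERT-style post-LN feedforward (up-projection, gelu_fast, down-projection, residual add, LayerNorm) rather than the old three-linear ReLU stack, and TransformerEncoderLayer picks up the same gelu_fast in place of torch.nn.GELU(). A minimal shape check of the refactored block as a sketch; hid_dim=64 is an arbitrary example value:

import torch

block = FeedforwardBlock(hid_dim=64)
x = torch.randn(2, 10, 64)  # (batch, seq, hid_dim)
out = block(x)
# Residual add plus LayerNorm over the last dim preserve the input shape.
assert out.shape == x.shape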