
Replace FeedforwardBlock with a correct implementation (#211)

* Replace FeedforwardBlock with a correct implementation

* Reduce number of classes in test_training
Max Ryabinin, 4 years ago
commit ca6d87a837
3 changed files with 17 additions and 14 deletions
  1. hivemind/server/layers/common.py (+14 −11)
  2. tests/test_moe.py (+1 −1)
  3. tests/test_training.py (+2 −2)

+ 14 - 11
hivemind/server/layers/common.py

@@ -2,21 +2,24 @@ import torch
 from torch import nn as nn
 
 
+# https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py
+@torch.jit.script
+def gelu_fast(x):
+    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
+
+
 class FeedforwardBlock(nn.Module):
     def __init__(self, hid_dim):
         super().__init__()
-        self.layers = nn.Sequential(
-            nn.Linear(hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, hid_dim),
-        )
+        self.ffn = nn.Linear(hid_dim, 4 * hid_dim)
+        self.ffn_output = nn.Linear(4 * hid_dim, hid_dim)
+        self.layer_norm = nn.LayerNorm(hid_dim, eps=1e-12)
 
     def forward(self, x):
-        return x + self.layers(x)
+        ffn_output = self.ffn(x)
+        ffn_output = gelu_fast(ffn_output)
+        ffn_output = self.ffn_output(ffn_output)
+        return self.layer_norm(x + ffn_output)
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -37,7 +40,7 @@ class TransformerEncoderLayer(nn.Module):
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
 
-        self.activation = torch.nn.GELU()
+        self.activation = gelu_fast
 
     def forward(self, src, src_key_padding_mask=None):
         # (N, S, E) -> (S, N, E)

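The new gelu_fast is the tanh-based GELU approximation used in HuggingFace Transformers: the constant 0.7978845608 is sqrt(2/pi), so the expression computes 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2))). The rewritten FeedforwardBlock follows the standard BERT-style pattern: linear up-projection, GELU, linear down-projection, then LayerNorm over the residual sum. A minimal sanity sketch (not part of the commit; it assumes the module path shown in the diff header):

import torch
from hivemind.server.layers.common import FeedforwardBlock, gelu_fast

x = torch.randn(16, 64)
# the approximation should stay close to the exact erf-based GELU
assert torch.allclose(gelu_fast(x), torch.nn.functional.gelu(x), atol=1e-2)

# the residual connection plus LayerNorm preserves the input shape
block = FeedforwardBlock(hid_dim=64)
assert block(x).shape == x.shape
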
+ 1 - 1
tests/test_moe.py

@@ -188,7 +188,7 @@ def test_client_anomaly_detection():
                                                         max_batch_size=16,
                                                         )
 
-    experts['expert.3'].expert.layers[0].weight.data[0, 0] = float('nan')
+    experts['expert.3'].expert.ffn.weight.data[0, 0] = float('nan')
 
     dht = hivemind.DHT(start=True, expiration=999)
     server = hivemind.Server(dht, experts, num_connection_handlers=1)

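The anomaly-detection test previously reached into the nn.Sequential by index (expert.layers[0]); after the rewrite, the first linear layer is the named submodule ffn, hence the updated attribute access. An illustrative sketch of why a single corrupted weight is enough to trip the check (standalone, not the actual test body):

import torch
from torch import nn

layer = nn.Linear(64, 256)
layer.weight.data[0, 0] = float('nan')
# one NaN weight poisons output feature 0 for every input row
out = layer(torch.randn(4, 64))
assert torch.isnan(out).any()
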
+ 2 - 2
tests/test_training.py

@@ -12,7 +12,7 @@ from hivemind import RemoteExpert, background_server, DHT, DecentralizedSGD
 
 @pytest.mark.forked
 def test_training(max_steps: int = 100, threshold: float = 0.9):
-    dataset = load_digits()
+    dataset = load_digits(n_class=2)
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
@@ -20,7 +20,7 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
                            no_dht=True) as (server_endpoint, dht_endpoint):
         expert1 = RemoteExpert('expert.0', server_endpoint)
         expert2 = RemoteExpert('expert.1', server_endpoint)
-        model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))
+        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
         opt = torch.optim.SGD(model.parameters(), lr=0.05)
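With n_class=2, load_digits returns only the digits 0 and 1, so the classification head shrinks from Linear(64, 10) to Linear(64, 2) and the task becomes easier to fit within max_steps. A standalone sketch of the same setup (a plain Linear stands in for the RemoteExpert layers, which need a running server; the loop body is an assumption, not the actual test code):

import torch
from torch import nn
from sklearn.datasets import load_digits

dataset = load_digits(n_class=2)
X = torch.tensor(dataset['data'], dtype=torch.float)  # 8x8 images -> 64 features
y = torch.tensor(dataset['target'])

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 2))
opt = torch.optim.SGD(model.parameters(), lr=0.05)

for step in range(100):  # max_steps = 100 in the test signature
    opt.zero_grad()
    loss = nn.functional.cross_entropy(model(X), y)
    loss.backward()
    opt.step()
    accuracy = (model(X).argmax(dim=1) == y).float().mean().item()
    if accuracy >= 0.9:  # threshold = 0.9 in the test signature
        break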