common.py

import time

import torch
from torch import nn as nn

from hivemind.moe.server.layers.custom_experts import register_expert_class


# https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py
@torch.jit.script
def gelu_fast(x):
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


ffn_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))


@register_expert_class("ffn", ffn_sample_input)
class FeedforwardBlock(nn.Module):
    """A Transformer-style feedforward block: Linear -> GELU -> Linear with a residual connection and LayerNorm."""

    def __init__(self, hid_dim):
        super().__init__()
        self.ffn = nn.Linear(hid_dim, 4 * hid_dim)
        self.ffn_output = nn.Linear(4 * hid_dim, hid_dim)
        self.layer_norm = nn.LayerNorm(hid_dim, eps=1e-12)

    def forward(self, x):
        ffn_output = self.ffn(x)
        ffn_output = gelu_fast(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        return self.layer_norm(x + ffn_output)


class TransformerEncoderLayer(nn.Module):
    """
    A slight modification of torch.nn.TransformerEncoderLayer which allows for torch.jit scripting
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = gelu_fast

    def forward(self, src, src_key_padding_mask=None):
        # nn.MultiheadAttention expects sequence-first inputs: (N, S, E) -> (S, N, E)
        src = src.transpose(0, 1)
        src2 = self.self_attn(src, src, src, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        # back to batch-first: (S, N, E) -> (N, S, E)
        src = src.transpose(0, 1)
        return src


transformer_sample_input = lambda batch_size, hid_dim: (
    torch.empty((batch_size, 128, hid_dim)),
    torch.empty((batch_size, 128), dtype=torch.bool),
)


@register_expert_class("transformer", transformer_sample_input)
class TunedTransformer(TransformerEncoderLayer):
    """A TransformerEncoderLayer sized by hid_dim, with a 4x feedforward expansion and 16 attention heads."""

    def __init__(self, hid_dim):
        super().__init__(hid_dim, dim_feedforward=4 * hid_dim, nhead=16)


nop_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))


@register_expert_class("nop", nop_sample_input)
class NopExpert(nn.Sequential):
    """An identity expert: returns its input unchanged and holds only a zero-size dummy parameter."""

    def __init__(self, hid_dim):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(0), requires_grad=True)

    def forward(self, x):
        return x.clone()


@register_expert_class("nop_delay", nop_sample_input)
class DelayedNopExpert(nn.Sequential):
    """Same as NopExpert, but sleeps for `delay` seconds in forward to simulate a slow expert."""

    def __init__(self, hid_dim, delay=0.5):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(0), requires_grad=True)
        self.delay = delay

    def forward(self, x):
        time.sleep(self.delay)
        return x.clone()
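

# Usage sketch (illustrative only): the registered experts are ordinary nn.Modules,
# so they can be exercised directly on inputs shaped like their sample-input
# factories. The batch size and hidden dimension below are arbitrary example
# values (hid_dim must stay divisible by nhead=16 for the transformer expert).
if __name__ == "__main__":
    batch_size, hid_dim = 2, 64

    ffn = FeedforwardBlock(hid_dim)
    print(ffn(torch.randn(batch_size, hid_dim)).shape)  # torch.Size([2, 64])

    transformer = TunedTransformer(hid_dim)
    src = torch.randn(batch_size, 128, hid_dim)
    padding_mask = torch.zeros(batch_size, 128, dtype=torch.bool)  # no padded positions
    print(transformer(src, padding_mask).shape)  # torch.Size([2, 128, 64])

    nop = NopExpert(hid_dim)
    print(nop(torch.randn(batch_size, hid_dim)).shape)  # torch.Size([2, 64])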