# test_ffn.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.modules.ffn import LeanFFN
  5. class ReferenceFFN(nn.Module):
  6. def __init__(self,
  7. hidden_size: int,
  8. intermediate_size: int,
  9. activation=F.gelu,
  10. layer_norm_eps=1e-12,
  11. dropout: float = 0.0):
  12. super().__init__()
  13. self.dense_i2h = nn.Linear(hidden_size, intermediate_size)
  14. self.dense_h2o = nn.Linear(intermediate_size, hidden_size)
  15. self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
  16. self.activation = activation
  17. self.dropout = dropout
  18. def forward(self, input):
  19. output = self.dense_i2h(self.layer_norm(input))
  20. output = self.activation(output)
  21. output = self.dense_h2o(output)
  22. output = F.dropout(output, self.dropout)
  23. return output + input
  24. def test_ffn_exact_match():
  25. torch.use_deterministic_algorithms(True)
  26. batch_size = 4
  27. seq_len = 128
  28. dim = 32
  29. num_layers = 4
  30. baseline_ffn = ReferenceFFN(dim, 4 * dim)
  31. our_ffn = LeanFFN(dim, 4 * dim)
  32. assert our_ffn.load_state_dict(baseline_ffn.state_dict())
  33. x = torch.rand(batch_size, seq_len, dim, device='cpu', requires_grad=True)
  34. # test outputs
  35. out_ref = x
  36. for i in range(num_layers):
  37. out_ref = baseline_ffn.forward(out_ref)
  38. out_our = x
  39. for i in range(num_layers):
  40. out_our = our_ffn(out_our)
  41. assert torch.allclose(out_our, out_ref)
  42. # test grad inputs
  43. obj = (out_ref * (out_ref + 1)).square().mean()
  44. grad_ref, = torch.autograd.grad(obj, x)
  45. obj = (out_our * (out_our + 1)).square().mean()
  46. grad_our, = torch.autograd.grad(obj, x)
  47. assert torch.allclose(grad_ref, grad_our)
  48. # test grad params
  49. x = torch.rand(batch_size, seq_len, dim, device='cpu', requires_grad=True)
  50. out_ref = x
  51. for i in range(num_layers):
  52. out_ref = baseline_ffn.forward(out_ref)
  53. out_our = x
  54. for i in range(num_layers):
  55. out_our = our_ffn(out_our)
  56. obj = (out_ref * (out_ref + 1)).square().mean()
  57. grad_params_ref = torch.autograd.grad(obj, list(baseline_ffn.parameters()))
  58. obj = (out_our * (out_our + 1)).square().mean()
  59. grad_params_our = torch.autograd.grad(obj, list(our_ffn.parameters()))
  60. for grad_ref, grad_our in zip(grad_params_ref, grad_params_our):
  61. assert torch.allclose(grad_ref, grad_our)