test_rotary.py

import torch

from lib.modules.rotary import get_auxiliary_tensors, RotaryEmbeddings


def test_rotary_embeddings():
    batch_size = 11
    seq_len = 50
    nhead = 4
    dim = 1024
    dtype = torch.float32
    device = torch.device('cpu')
    base = 10 ** 4
    torch.use_deterministic_algorithms(True)

    # auxiliary tensors
    a, b = get_auxiliary_tensors(seq_len, dim, dtype, device, base)
    x, y = Rotary(dim, base).forward(torch.randn(1, seq_len, dim, device=device))
    assert torch.allclose(a.view_as(x), x, atol=1e-4, rtol=0)
    assert torch.allclose(b.view_as(y), y, atol=1e-4, rtol=0)

    # full layer outputs
    ref_layer = Rotary(dim, base)
    our_layer = RotaryEmbeddings(dim, base)
    q = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
    k = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
    q_ref, k_ref = apply_rotary_pos_emb(q.permute(1, 0, 2, 3), k.permute(1, 0, 2, 3), *ref_layer(k))
    q_our, k_our = our_layer(q), our_layer(k)
    assert torch.allclose(q_ref.permute(1, 0, 2, 3), q_our, atol=1e-4, rtol=0)
    assert torch.allclose(k_ref.permute(1, 0, 2, 3), k_our, atol=1e-4, rtol=0)
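
    # RotaryEmbeddings is exercised here as a drop-in layer: it is built as
    # RotaryEmbeddings(dim, base), applied directly to (batch_size, seq_len, nhead, dim)
    # tensors, and returns rotated tensors of the same shape. The EleutherAI reference
    # below works on (seq_len, batch, nhead, dim) instead, hence the permutes around
    # apply_rotary_pos_emb.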

    # same-position dot products must be preserved: q and k at position t are rotated
    # by the same angles, so their relative rotation is zero (a more general
    # relative-position check is sketched at the end of this file)
    original_dot = (q[0, :, 0] * k[0, :, 0]).sum(-1)
    rotated_dot = (our_layer(q)[0, :, 0] * our_layer(k)[0, :, 0]).sum(-1)
    assert torch.allclose(original_dot, rotated_dot, atol=1e-4, rtol=0)


class Rotary(torch.nn.Module):
    """ Reference implementation by EleutherAI. """
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached
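
# In Rotary, inv_freq[i] = base ** (-2 * i / dim) for i in range(dim // 2), so
# freqs[t, i] = t * inv_freq[i]. Concatenating freqs with itself yields one angle per
# feature, and cos_cached / sin_cached have shape (seq_len, 1, 1, dim) so they
# broadcast against the (seq_len, batch, nhead, dim) inputs of apply_rotary_pos_emb.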


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=x1.ndim - 1
    )  # dim=-1 triggers a bug in torch < 1.8.0


def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
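

# The last check in test_rotary_embeddings is the same-position case of a more general
# property: apply_rotary_pos_emb rotates each pair (x[i], x[i + dim // 2]) by the
# position-dependent angle t * inv_freq[i], so the dot product between a rotated query
# at position m and a rotated key at position n depends only on the offset m - n.
# Below is a minimal sketch of that property using only the reference code above; the
# helper _relative_dot is made up for this illustration and is not part of the module
# under test.
def _relative_dot(q_vec, k_vec, m, n, dim, base=10000):
    """Dot product of q_vec rotated to position m with k_vec rotated to position n."""
    # Build cos/sin tables up to the largest position needed, then pick single rows.
    cos, sin = Rotary(dim, base)(torch.zeros(1, max(m, n) + 1, dim))

    def rotate(v, pos):
        return v * cos[pos, 0, 0] + rotate_half(v) * sin[pos, 0, 0]

    return (rotate(q_vec, m) * rotate(k_vec, n)).sum(-1)


if __name__ == "__main__":
    # Quick self-check: equal offsets give (numerically) equal dot products.
    torch.manual_seed(0)
    q_vec, k_vec = torch.randn(64), torch.randn(64)
    assert torch.allclose(
        _relative_dot(q_vec, k_vec, m=3, n=1, dim=64),
        _relative_dot(q_vec, k_vec, m=9, n=7, dim=64),
        atol=1e-4,
        rtol=0,
    )
    # The test above is meant to be collected by a test runner, e.g.
    # (assuming pytest is installed): pytest test_rotary.py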