import torch

from lib.modules.rotary import get_auxiliary_tensors, RotaryEmbeddings
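# `get_auxiliary_tensors` and `RotaryEmbeddings` are the implementations under
# test; the reference `Rotary` module and `apply_rotary_pos_emb` helper they
# are compared against are defined at the bottom of this file.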


def test_rotary_embeddings():
    batch_size = 11
    seq_len = 50
    nhead = 4
    dim = 1024
    dtype = torch.float32
    device = torch.device('cpu')
    base = 10 ** 4
    torch.use_deterministic_algorithms(True)

    # auxiliary tensors: get_auxiliary_tensors should reproduce the reference cos/sin tables
    a, b = get_auxiliary_tensors(seq_len, dim, dtype, device, base)
    cos_ref, sin_ref = Rotary(dim, base)(torch.randn(1, seq_len, dim, device=device))
    assert torch.allclose(a.view_as(cos_ref), cos_ref, atol=1e-4, rtol=0)
    assert torch.allclose(b.view_as(sin_ref), sin_ref, atol=1e-4, rtol=0)

    # full layer outputs
    ref_layer = Rotary(dim, base)
    our_layer = RotaryEmbeddings(dim, base)
    q = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
    k = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
    # the reference implementation expects sequence-first (seq, batch, head, dim) tensors
    q_ref, k_ref = apply_rotary_pos_emb(q.permute(1, 0, 2, 3), k.permute(1, 0, 2, 3), *ref_layer(k))
    q_our, k_our = our_layer(q), our_layer(k)
    assert torch.allclose(q_ref.permute(1, 0, 2, 3), q_our, atol=1e-4, rtol=0)
    assert torch.allclose(k_ref.permute(1, 0, 2, 3), k_our, atol=1e-4, rtol=0)

    # rotary embeddings rotate q and k at each position by the same angle,
    # so the per-position dot product is preserved
    original_dot = (q[0, :, 0] * k[0, :, 0]).sum(-1)
    rotated_dot = (q_our[0, :, 0] * k_our[0, :, 0]).sum(-1)
    assert torch.allclose(original_dot, rotated_dot, atol=1e-4, rtol=0)
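

# Hedged extra check (an addition, not part of the original suite): RoPE's
# defining property is that attention scores depend only on relative position.
# Using the reference `Rotary` layer, the score between q rotated to position i
# and k rotated to position j should be unchanged when both positions shift.
def test_relative_position_property():
    dim, seq_len = 16, 12
    q = torch.randn(dim)
    k = torch.randn(dim)
    cos, sin = Rotary(dim)(torch.zeros(seq_len, 1, 1, dim), seq_dim=0)

    def score(i, j):
        q_i = q * cos[i, 0, 0] + rotate_half(q) * sin[i, 0, 0]
        k_j = k * cos[j, 0, 0] + rotate_half(k) * sin[j, 0, 0]
        return (q_i * k_j).sum()

    # positions (3, 1) and (7, 5) share the relative offset 2
    assert torch.allclose(score(3, 1), score(7, 5), atol=1e-5)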


class Rotary(torch.nn.Module):
    """Reference implementation by EleutherAI."""

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            # recompute the cos/sin tables only when the sequence length changes
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=x1.ndim - 1
    )  # dim=-1 triggers a bug in torch < 1.8.0


def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
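

# Hedged sketch (an addition): in the EleutherAI layout, `apply_rotary_pos_emb`
# pairs channel i with channel i + dim // 2 and rotates each pair by the angle
# t * inv_freq[i]. The explicit 2x2-rotation form below should match it exactly;
# `apply_rotary_pos_emb_explicit` is a hypothetical helper written for this check.
def apply_rotary_pos_emb_explicit(x, cos, sin):
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # both halves of cos/sin hold the same values, since emb = cat((freqs, freqs))
    c, s = cos[..., :half], sin[..., :half]
    return torch.cat((x1 * c - x2 * s, x1 * s + x2 * c), dim=-1)


def test_rotate_half_equivalence():
    dim, seq_len = 16, 8
    x = torch.randn(seq_len, 1, 1, dim)
    cos, sin = Rotary(dim)(x, seq_dim=0)
    q_rot, _ = apply_rotary_pos_emb(x, x, cos, sin)
    assert torch.allclose(q_rot, apply_rotary_pos_emb_explicit(x, cos, sin), atol=1e-6)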