
Use dalle-pytorch instead of LeanAlbert

Aleksandr Borzunov, 3 years ago
commit 20e2a3aab2
9 changed files with 112 additions and 191 deletions
  1. arguments.py (+9 -7)
  2. data.py (+25 -0)
  3. lib/__init__.py (+0 -2)
  4. lib/training/hf_trainer.py (+4 -6)
  5. run_aux_peer.py (+6 -2)
  6. run_trainer.py (+5 -4)
  7. task.py (+63 -17)
  8. tests/test_ffn.py (+0 -83)
  9. tests/test_rotary.py (+0 -70)

+ 9 - 7
arguments.py

@@ -12,15 +12,18 @@ class HFTrainerArguments(TrainingArguments):
     per_device_train_batch_size: int = 1
     per_device_eval_batch_size: int = 1
     gradient_accumulation_steps: int = 1
-    seq_length: int = 512
-    pad_to_multiple_of: int = 8
+    text_seq_length: int = 256
+
+    # DALLE-specific params
+    learning_rate: float = 0.003535
+    adam_beta1: float = 0.9
+    adam_beta2: float = 0.96
+    max_grad_norm: float = 4.0
+    weight_decay: float = 0.045
 
-    learning_rate: float = 0.0025
     total_steps: int = 31250  # total number of collaborative SGD updates, used for learning rate schedule
     warmup_steps: int = 3125
     adam_epsilon: float = 1e-6
-    weight_decay: float = 0.01
-    max_grad_norm: float = 1.0
     clamp_value: float = 10000.0
 
     fp16: bool = False
@@ -103,8 +106,7 @@ class CollaborativeArguments:
 class BasePeerArguments:
     """Base arguments that are used for both trainers and for auxiliary peers such as training monitor"""
     experiment_prefix: str = field(default="my-model", metadata={"help": "A unique experiment name, used as prefix for all DHT keys"})
-    model_config_path: Optional[str] = field(default="./model.json", metadata={"help": "Path to the model config"})
-    tokenizer_path: Optional[str] = field(default="./tokenizer", metadata={"help": "Path to the tokenizer"})
+    tokenizer_path: Optional[str] = field(default="gpt2", metadata={"help": "Path to the tokenizer"})
     cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
 
     authorize: bool = field(default=False, metadata={"help": "Whether or not to use HF authorizer"})
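
For context, the DALLE-specific defaults above are ordinary dataclass fields, so they are consumed through HfArgumentParser like every other training argument and can be overridden from the command line; a minimal sketch (the flag values are illustrative, not recommended settings):

# Minimal sketch: HFTrainerArguments now carries the DALLE-specific defaults
# (learning_rate=0.003535, adam_beta2=0.96, max_grad_norm=4.0, weight_decay=0.045,
# text_seq_length=256); any of them can still be overridden on the command line.
from transformers import HfArgumentParser

from arguments import HFTrainerArguments

parser = HfArgumentParser(HFTrainerArguments)
# Equivalent to: python run_trainer.py --output_dir outputs --text_seq_length 128
(trainer_args,) = parser.parse_args_into_dataclasses(["--output_dir", "outputs", "--text_seq_length", "128"])
print(trainer_args.learning_rate, trainer_args.adam_beta2, trainer_args.text_seq_length)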

+ 25 - 0
data.py

@@ -0,0 +1,25 @@
+from typing import Optional
+
+import hivemind
+import numpy as np
+from datasets import load_dataset
+
+logger = hivemind.get_logger(__name__)
+
+
+def make_dataset(
+    tokenizer,
+    *,
+    shuffle_buffer_size: int = 10 ** 4,
+    shuffle_seed: Optional[int],
+    preprocessing_batch_size: int = 256,
+    max_sequence_length: int,
+):
+    ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
+    ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
+    ds = ds.map(lambda item: dict(
+        tokenizer(item['caption'], truncation=True, max_length=max_sequence_length),
+        image=np.stack([np.frombuffer(encoded, np.int16).astype(np.int64) for encoded in item['code']]),
+    ), batched=True, batch_size=preprocessing_batch_size)
+    ds = ds.with_format('torch')
+    return ds
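
For reference, a rough sketch of how the new streaming pipeline can be exercised on its own (illustrative only; it assumes Hub access to laion/laion_100m_vqgan_f8 and mirrors the tokenizer setup from task.py):

# Rough sketch (illustrative only): stream one preprocessed example from the new dataset.
from transformers import GPT2TokenizerFast

from data import make_dataset

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

dataset = make_dataset(tokenizer, shuffle_seed=0, max_sequence_length=256)
example = next(iter(dataset))

# 'input_ids'/'attention_mask' come from the GPT-2 tokenizer (truncated to 256 tokens,
# padding happens later in the collator); 'image' holds the precomputed VQGAN code indices.
print(len(example["input_ids"]), example["image"].shape)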

+ 0 - 2
lib/__init__.py

@@ -1,2 +0,0 @@
-from .modules import *
-from .models import *

+ 4 - 6
lib/training/hf_trainer.py

@@ -1,15 +1,12 @@
 """A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training"""
-from typing import Optional
-
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
 from transformers.trainer import Trainer
+from hivemind import CollaborativeOptimizer
+from hivemind.optim import HivemindGradScaler
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from lib.staging.collaborative import CollaborativeOptimizer
-from lib.staging.scaler import HivemindGradScaler
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger()
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
@@ -74,7 +71,8 @@ class IgnoreGradManipulations(nn.Module):
         return self.module.forward(*args, **kwargs)
 
     def zero_grad(self, set_to_none: bool = False) -> None:
-        if self.override_zero_grad and all(param.grad.isfinite().all() for param in self.parameters()):
+        if self.override_zero_grad and \
+                all(param.grad.isfinite().all() for param in self.parameters() if param.requires_grad):
             logger.debug("Successfully bypassed zero_grad")
         else:
             self.module.zero_grad(set_to_none=set_to_none)
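
The added requires_grad filter matters because frozen parameters (such as the pretrained VQGAN inside DALLE) never receive gradients, so their .grad stays None and the old check would raise. A small toy illustration of the guarded check (the model below is made up for the example):

# Toy illustration (not repo code): frozen parameters keep grad=None after backward,
# so the finiteness check must skip them via requires_grad, as in the diff above.
import torch
from torch import nn

trainable = nn.Linear(4, 4)
frozen = nn.Linear(4, 4).requires_grad_(False)  # stands in for the pretrained, frozen VQGAN
model = nn.Sequential(trainable, frozen)

model(torch.randn(2, 4)).sum().backward()

grads_finite = all(p.grad.isfinite().all() for p in model.parameters() if p.requires_grad)
print(grads_finite)  # True; without the filter, p.grad is None for the frozen params and .isfinite() fails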

+ 6 - 2
run_aux_peer.py

@@ -4,6 +4,7 @@ import time
 
 import torch
 import wandb
+import transformers
 from transformers import HfArgumentParser
 from huggingface_hub import HfFolder, Repository
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -12,8 +13,11 @@ import utils
 from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments
 from task import TrainingTask
 
+
+transformers.utils.logging.disable_default_handler()
+transformers.utils.logging.enable_propagation()
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger()
+logger = get_logger(__name__)
 
 
 class CheckpointHandler:
@@ -56,7 +60,7 @@ class CheckpointHandler:
 
     def upload_checkpoint(self, current_loss):
         logger.info("Saving model")
-        self.task.model.save_pretrained(self.local_path)
+        torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt")
         logger.info("Saving optimizer")
         torch.save(self.task.collaborative_optimizer.opt.state_dict(), f"{self.local_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
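
Because the model is now saved as a plain state dict instead of via save_pretrained, restoring a checkpoint requires rebuilding the wrapped DALLE first and then loading both state dicts; a hypothetical helper sketching that (function name and paths are illustrative, not part of this commit):

# Illustrative helper: load the files written by upload_checkpoint above.
# Assumes `model` was constructed the same way as in task.py and `opt` is the wrapped optimizer.
import torch
from torch import nn


def restore_checkpoint(model: nn.Module, opt: torch.optim.Optimizer, local_path: str) -> None:
    model.load_state_dict(torch.load(f"{local_path}/model_state.pt", map_location="cpu"))
    opt.load_state_dict(torch.load(f"{local_path}/optimizer_state.pt", map_location="cpu"))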

+ 5 - 4
run_trainer.py

@@ -3,7 +3,6 @@
 import os
 from pathlib import Path
 
-
 import transformers
 from transformers import HfArgumentParser
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -16,8 +15,10 @@ from arguments import TrainingPeerArguments, HFTrainerArguments, CollaborativeAr
 from task import TrainingTask
 
 
+transformers.utils.logging.disable_default_handler()
+transformers.utils.logging.enable_propagation()
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger()
+logger = get_logger(__name__)
 
 
 def main():
@@ -25,8 +26,8 @@ def main():
     training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
 
     logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}")
-    if len(training_peer_args.initial_peers) == 0:
-        logger.warning("Please specify at least one network endpoint in initial peers.")
+    # if len(training_peer_args.initial_peers) == 0:
+    #     logger.warning("Please specify at least one network endpoint in initial peers.")
 
     utils.setup_logging(trainer_args)
     task = TrainingTask(training_peer_args, trainer_args, collab_args)

+ 63 - 17
task.py

@@ -1,23 +1,41 @@
 import os
 from dataclasses import asdict
+from itertools import cycle, islice
 from pathlib import Path
 
 import hivemind
+import torch
 import transformers
+from dalle_pytorch import DALLE
+from dalle_pytorch.vae import VQGanVAE, download
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
-from transformers import AlbertTokenizerFast, get_linear_schedule_with_warmup, DataCollatorForLanguageModeling
+from transformers import DataCollatorWithPadding, GPT2TokenizerFast, get_linear_schedule_with_warmup
+from torch import nn
 
 import utils
 from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
 from data import make_dataset
 from huggingface_auth import authorize_with_huggingface
-from lib import LeanAlbertConfig, LeanAlbertForPreTraining
-from lib.staging.collaborative import CollaborativeOptimizer
 from lib.training.clipped_lamb import LambWithGradientClipping
 from lib.training.offload import OffloadOptimizer
 
-hivemind.use_hivemind_log_handler("in_root_logger")
-logger = hivemind.get_logger()
+
+logger = hivemind.get_logger(__name__)
+
+# VQGAN with downsampling factor f=8, 8192 codebook entries, and Gumbel quantization
+# Note: If you change the URLs below, remove ./cache/* to clear the cache
+VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1'
+VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1'
+
+
+class ModelWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask, image):
+        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
+        return {'loss': loss}
 
 
 class TrainingTask:
@@ -30,8 +48,9 @@ class TrainingTask:
         self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
         self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
         transformers.set_seed(trainer_args.seed)  # seed used for initialization
-        self.config = LeanAlbertConfig.from_pretrained(peer_args.model_config_path)
-        self.tokenizer = AlbertTokenizerFast.from_pretrained(peer_args.tokenizer_path, cache_dir=peer_args.cache_dir)
+
+        self.tokenizer = GPT2TokenizerFast.from_pretrained(peer_args.tokenizer_path)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
 
         output_dir = Path(trainer_args.output_dir)
         logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
@@ -39,11 +58,37 @@
 
         if latest_checkpoint_dir is None:
             logger.info(f"Creating model")
-            self.model = LeanAlbertForPreTraining(self.config)
-            self.model.resize_token_embeddings(len(self.tokenizer))
+
+            vae = VQGanVAE(
+                vqgan_model_path=download(VQGAN_VAE_PATH, 'vqgan.ckpt', root=peer_args.cache_dir),
+                vqgan_config_path=download(VQGAN_VAE_CONFIG_PATH, 'vqgan_config.yaml', root=peer_args.cache_dir),
+            )
+
+            depth = 64
+            attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
+            attn_types.append('conv_like')
+            shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
+            shared_layer_ids.append('w_conv')
+            dalle = DALLE(
+                vae=vae,
+                num_text_tokens=self.tokenizer.vocab_size,
+                text_seq_len=trainer_args.text_seq_length,
+                dim=1024,
+                depth=depth,
+                heads=16,
+                dim_head=64,
+                attn_types=attn_types,
+                ff_dropout=0,
+                attn_dropout=0,
+                shared_attn_ids=shared_layer_ids,
+                shared_ff_ids=shared_layer_ids,
+                rotary_emb=False,  # FIXME: Fix RuntimeError when True
+                reversible=True,
+            )
+            self.model = ModelWrapper(dalle)
         else:
             logger.info(f"Loading model from {latest_checkpoint_dir}")
-            self.model = LeanAlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
+            self.task.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
 
     @property
     def dht(self):
@@ -72,7 +117,7 @@
             averaging_compression = SizeAdaptiveCompression(
                 threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
             state_compression = hivemind.Float16Compression()
-            self._collaborative_optimizer = CollaborativeOptimizer(
+            self._collaborative_optimizer = hivemind.CollaborativeOptimizer(
                 dht=self.dht, opt=opt, scheduler=scheduler, prefix=self.peer_args.experiment_prefix,
                 batch_size_per_step=self.trainer_args.batch_size_per_step,
                 compression=averaging_compression, state_compression=state_compression,
@@ -83,11 +128,13 @@
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
             {
-                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "params": [p for n, p in self.model.named_parameters()
+                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                 "weight_decay": training_args.weight_decay,
             },
             {
-                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                "params": [p for n, p in self.model.named_parameters()
+                           if any(nd in n for nd in no_decay) and p.requires_grad],
                 "weight_decay": 0.0,
             },
         ]
@@ -115,12 +162,11 @@
         if self._training_dataset is None:
             self._training_dataset = make_dataset(
                 self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
-                max_sequence_length=self.trainer_args.seq_length
+                max_sequence_length=self.trainer_args.text_seq_length
             )
         return self._training_dataset
 
     @property
     def data_collator(self):
-        return DataCollatorForLanguageModeling(
-            tokenizer=self.tokenizer, pad_to_multiple_of=self.trainer_args.pad_to_multiple_of
-        )
+        return DataCollatorWithPadding(tokenizer=self.tokenizer,
+                                       padding='max_length', max_length=self.trainer_args.text_seq_length)
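
The islice(cycle(...)) construction above is compact enough to deserve unpacking; a standalone sketch of what it yields for depth=64, with the values taken straight from the diff (the weight-sharing reading follows dalle_pytorch's shared_attn_ids/shared_ff_ids convention, where layers with equal ids reuse the same parameters):

# Standalone sketch: the attention pattern and weight-sharing ids built in TrainingTask.
# The first 63 layers cycle through sparse axial attention types and reuse four shared
# attention/FF weight sets (ids 0-3); the final layer gets 'conv_like' attention with
# its own 'w_conv' parameters.
from itertools import cycle, islice

depth = 64
attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
attn_types.append('conv_like')
shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
shared_layer_ids.append('w_conv')

print(attn_types[:5])         # ['axial_row', 'axial_col', 'axial_row', 'axial_row', 'axial_row']
print(shared_layer_ids[:5])   # [0, 1, 2, 3, 0]
print(attn_types[-1], shared_layer_ids[-1])     # conv_like w_conv
print(len(attn_types), len(shared_layer_ids))   # 64 64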

+ 0 - 83
tests/test_ffn.py

@@ -1,83 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from lib.modules.ffn import LeanFFN
-
-
-class ReferenceFFN(nn.Module):
-
-    def __init__(self,
-                 hidden_size: int,
-                 intermediate_size: int,
-                 activation=F.gelu,
-                 layer_norm_eps=1e-12,
-                 dropout: float = 0.0):
-        super().__init__()
-        self.dense_i2h = nn.Linear(hidden_size, intermediate_size)
-        self.dense_h2o = nn.Linear(intermediate_size, hidden_size)
-        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.activation = activation
-        self.dropout = dropout
-
-    def forward(self, input):
-        output = self.dense_i2h(self.layer_norm(input))
-        output = self.activation(output)
-        output = self.dense_h2o(output)
-        output = F.dropout(output, self.dropout)
-        return output + input
-
-
-def test_ffn_exact_match():
-    torch.use_deterministic_algorithms(True)
-
-    batch_size = 4
-    seq_len = 128
-    dim = 32
-    num_layers = 4
-
-    baseline_ffn = ReferenceFFN(dim, 4 * dim)
-    our_ffn = LeanFFN(dim, 4 * dim)
-
-    assert our_ffn.load_state_dict(baseline_ffn.state_dict())
-
-    x = torch.rand(batch_size, seq_len, dim, device='cpu', requires_grad=True)
-
-    # test outputs
-    out_ref = x
-    for i in range(num_layers):
-        out_ref = baseline_ffn.forward(out_ref)
-
-    out_our = x
-    for i in range(num_layers):
-        out_our = our_ffn(out_our)
-
-    assert torch.allclose(out_our, out_ref)
-
-    # test grad inputs
-    obj = (out_ref * (out_ref + 1)).square().mean()
-    grad_ref, = torch.autograd.grad(obj, x)
-
-    obj = (out_our * (out_our + 1)).square().mean()
-    grad_our, = torch.autograd.grad(obj, x)
-    assert torch.allclose(grad_ref, grad_our)
-
-    # test grad params
-    x = torch.rand(batch_size, seq_len, dim, device='cpu', requires_grad=True)
-
-    out_ref = x
-    for i in range(num_layers):
-        out_ref = baseline_ffn.forward(out_ref)
-
-    out_our = x
-    for i in range(num_layers):
-        out_our = our_ffn(out_our)
-
-    obj = (out_ref * (out_ref + 1)).square().mean()
-    grad_params_ref = torch.autograd.grad(obj, list(baseline_ffn.parameters()))
-
-    obj = (out_our * (out_our + 1)).square().mean()
-    grad_params_our = torch.autograd.grad(obj, list(our_ffn.parameters()))
-
-    for grad_ref, grad_our in zip(grad_params_ref, grad_params_our):
-        assert torch.allclose(grad_ref, grad_our)

+ 0 - 70
tests/test_rotary.py

@@ -1,70 +0,0 @@
-import torch
-
-from lib.modules.rotary import get_auxiliary_tensors, RotaryEmbeddings
-
-
-def test_rotary_embeddings():
-    batch_size = 11
-    seq_len = 50
-    nhead = 4
-    dim = 1024
-    dtype = torch.float32
-    device = torch.device('cpu')
-    base = 10 ** 4
-
-    torch.use_deterministic_algorithms(True)
-
-    # auxiliary tensors
-    a, b = get_auxiliary_tensors(seq_len, dim, dtype, device, base)
-    x, y = Rotary(dim, base).forward(torch.randn(1, seq_len, dim, device=device))
-    assert torch.allclose(a.view_as(x), x, atol=1e-4, rtol=0)
-    assert torch.allclose(b.view_as(y), y, atol=1e-4, rtol=0)
-
-    # full layer outputs
-    ref_layer = Rotary(dim, base)
-    our_layer = RotaryEmbeddings(dim, base)
-    q = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
-    k = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
-
-    q_ref, k_ref = apply_rotary_pos_emb(q.permute(1, 0, 2, 3), k.permute(1, 0, 2, 3), *ref_layer(k))
-    q_our, k_our = our_layer(q), our_layer(k)
-    assert torch.allclose(q_ref.permute(1, 0, 2, 3), q_our, atol=1e-4, rtol=0)
-    assert torch.allclose(k_ref.permute(1, 0, 2, 3), k_our, atol=1e-4, rtol=0)
-
-    # check rotation equivariance of dot product
-    original_dot = (q[0, :, 0] * k[0, :, 0]).sum(-1)
-    rotated_dot = (our_layer(q)[0, :, 0] * our_layer(k)[0, :, 0]).sum(-1)
-    assert torch.allclose(original_dot, rotated_dot, atol=1e-4, rtol=0)
-
-
-class Rotary(torch.nn.Module):
-    """ Reference implementation by ElutherAI """
-    def __init__(self, dim, base=10000):
-        super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq)
-        self.seq_len_cached = None
-        self.cos_cached = None
-        self.sin_cached = None
-
-    def forward(self, x, seq_dim=1):
-        seq_len = x.shape[seq_dim]
-        if seq_len != self.seq_len_cached:
-            self.seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.cos_cached = emb.cos()[:, None, None, :]
-            self.sin_cached = emb.sin()[:, None, None, :]
-        return self.cos_cached, self.sin_cached
-
-
-def rotate_half(x):
-    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
-    return torch.cat(
-        (-x2, x1), dim=x1.ndim - 1
-    )  # dim=-1 triggers a bug in torch < 1.8.0
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)