Explorar el Código

WIP Triton+QKV merge

Max Ryabinin hace 1 año
padre
commit
fa464dfc99

+ 116 - 2
src/petals/models/llama/block.py

@@ -6,10 +6,124 @@ See commit history for authorship.
 from typing import Optional, Tuple
 
 import torch
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaConfig,
+    LlamaDecoderLayer,
+    LlamaMLP,
+    LlamaModel,
+    LlamaRMSNorm,
+    repeat_kv,
+    apply_rotary_pos_emb,
+)
 
+from petals.triton import attention_triton_wrapper, rbe_triton_wrapper, rmsnorm_triton_wrapper
 
-class WrappedLlamaBlock(LlamaDecoderLayer):
+
+class OptimizedLlamaRMSNorm(LlamaRMSNorm):
+    def forward(self, hidden_states):
+        if torch.is_inference_mode_enabled():
+            return rmsnorm_triton_wrapper(hidden_states, self.weight)
+        return super().forward(hidden_states)
+
+
+class OptimizedLlamaAttention(LlamaAttention):
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
+        self.qkv_proj = nn.Linear(
+            self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False
+        )
+        self.qkv_sizes = [
+            self.num_heads * self.head_dim,
+            self.num_key_value_heads * self.head_dim,
+            self.num_key_value_heads * self.head_dim,
+        ]
+        self.attn_norm_constant = math.sqrt(self.head_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+        assert (
+            self.config.pretraining_tp == 1
+        ), "OptimizedLlamaAttention assumes that config.pretraining_tp is equal to 1"
+
+        query_states, key_states, value_states = torch.split(self.qkv_proj(hidden_states), self.qkv_sizes, dim=2)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_norm_constant
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: LlamaConfig):
+        nn.Module.__init__(self)
+        self.hidden_size = config.hidden_size
+        self.self_attn = OptimizedLlamaAttention(config=config)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
+class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,

+ 114 - 4
src/petals/server/throughput.py

@@ -1,6 +1,8 @@
+from __future__ import annotations
+
+import argparse
 import fcntl
 import json
-import math
 import multiprocessing as mp
 import os
 import time
@@ -8,14 +10,19 @@ from collections import Counter
 from pathlib import Path
 from typing import Dict, Optional, Sequence, Union
 
+import configargparse
 import torch
+
 import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
+from petals.constants import DTYPE_MAP
 from petals.server.block_utils import resolve_block_dtype
+from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
 
@@ -114,6 +121,7 @@ def measure_throughput_info(
     *,
     quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
+    measure_network: bool = True,
 ) -> Dict[str, float]:
     logger.info(
         "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
@@ -139,14 +147,16 @@ def measure_throughput_info(
             n_steps=10,
             inference=False,
         ),
-        "network_rps": measure_network_rps(config),
+        "network_rps": measure_network_rps(config, use_default=not measure_network),
     }
 
 
 def measure_network_rps(
-    config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6  # 100 Mbit/s
+    config: PretrainedConfig, *, use_default=False, timeout: float = 60, default_speed: float = 100e6  # 100 Mbit/s
 ) -> Optional[float]:
     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
+    if use_default:
+        return default_speed / bits_per_request
     try:
         pipe_recv, pipe_send = mp.Pipe(duplex=False)
         process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
@@ -207,13 +217,23 @@ def measure_compute_rps(
         cache = None
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
+        # with torch.profiler.profile(
+        #         schedule=torch.profiler.schedule(wait=1, warmup=4, active=n_steps, repeat=1),
+        #         on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profbf16_70b_qkv'),
+        #         record_shapes=True,
+        #         profile_memory=True,
+        #         with_stack=True
+        # ) as prof:
         _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
         synchronize(device)
+        # prof.step()
 
         start_time = time.perf_counter()
         for _ in range(n_steps):
             _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-        synchronize(device)
+            synchronize(device)
+            # prof.step()
+           
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
 
@@ -245,3 +265,93 @@ def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
     if quant_type != QuantType.NONE:
         name += f", quantized to {quant_type.name.lower()}"
     return name
+
+
+def parse_args():
+    # fmt:off
+    parser = configargparse.ArgParser(default_config_files=["config.yml"],
+                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
+
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument('--converted_model_name_or_path', type=str, default=None,
+                       help="path or name of a pretrained model, converted with cli/convert_model.py")
+    group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
+
+    group = parser.add_mutually_exclusive_group(required=False)
+    group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
+    group.add_argument("--use_auth_token", action="store_true", dest="token",
+                       help="Read token saved by `huggingface-cli login")
+
+    parser.add_argument('--device', type=str, default=None, required=False,
+                        help='all blocks will use this device in torch notation; default: cuda if available else cpu')
+    parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
+                        help="Use this dtype to store block weights and do computations. "
+                             "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--revision', type=str, default=None,
+                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
+                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
+
+    parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
+                        help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
+                             "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
+                             "Default: 'int8' if GPU is available, 'none' otherwise")
+    parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
+                        help=
+                        "Split each block between the specified GPUs such that each device holds a portion of every "
+                        "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
+
+    # fmt:on
+    args = parser.parse_args()
+    args.converted_model_name_or_path = args.model
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    converted_model_name_or_path = get_compatible_model_repo(args.converted_model_name_or_path)
+    config = AutoDistributedConfig.from_pretrained(
+        converted_model_name_or_path,
+        use_auth_token=args.token,
+        revision=args.revision,
+    )
+
+    device = args.device
+    if device is None:
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
+    device = torch.device(device)
+    if device.type == "cuda" and device.index is None:
+        device = torch.device(device.type, index=0)
+
+    torch_dtype = resolve_block_dtype(config, DTYPE_MAP[args.torch_dtype])
+    if device.type == "cpu" and torch_dtype == torch.float16:
+        raise ValueError(
+            f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
+        )
+    if device.type == "mps" and torch_dtype == torch.bfloat16:
+        logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
+        torch_dtype = torch.float16
+
+    quant_type = args.quant_type
+    if quant_type is None:
+        if device.type == "cuda":
+            quant_type = QuantType.NF4 if config.model_type == "llama" else QuantType.INT8
+        else:
+            quant_type = QuantType.NONE
+
+    if args.tensor_parallel_devices is None:
+        args.tensor_parallel_devices = (device,)
+
+    measure_throughput_info(
+        config,
+        device,
+        torch_dtype,
+        quant_type=quant_type,
+        tensor_parallel_devices=args.tensor_parallel_devices,
+        measure_network=False,
+    )

+ 3 - 0
src/petals/triton/__init__.py

@@ -0,0 +1,3 @@
+from petals.triton.rmsnorm import rmsnorm_triton_wrapper
+from petals.triton.attention import attention_triton_wrapper
+from petals.triton.rotary import rbe_triton_wrapper

+ 178 - 0
src/petals/triton/attention.py

@@ -0,0 +1,178 @@
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel(
+    Q,
+    K,
+    V,
+    sm_scale,
+    Out,
+    stride_qz,
+    stride_qh,
+    stride_qm,
+    stride_qk,
+    stride_kz,
+    stride_kh,
+    stride_kn,
+    stride_kk,
+    stride_vz,
+    stride_vh,
+    stride_vk,
+    stride_vn,
+    stride_oz,
+    stride_oh,
+    stride_om,
+    stride_on,
+    N_HEAD,
+    H,
+    N_CTX,
+    start_position,  # <- ADDED
+    IS_CAUSAL: tl.constexpr,  # <- ADDED
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    USE_FP8: tl.constexpr,
+):
+    start_m = tl.program_id(0)
+
+    head_idx = tl.program_id(1)
+    batch_id = head_idx // N_HEAD
+    off_hz = head_idx % N_HEAD
+
+    # initialize offsets
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    off_q = (
+        batch_id * stride_qz + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+    )  # <- stride fixed
+    off_k = (
+        batch_id * stride_kz + off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
+    )  # <- stride fixed
+    off_v = (
+        batch_id * stride_vz + off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
+    )  # <- stride fixed
+    # Initialize pointers to Q, K, V
+    q_ptrs = Q + off_q
+    k_ptrs = K + off_k
+    v_ptrs = V + off_v
+    # initialize pointer to m and l
+    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    # load q: it will stay in SRAM throughout
+    q = tl.load(q_ptrs, offs_m[:, None] < H, other=0.0)
+    # loop over k, v and update accumulator
+    block_n_end = N_CTX  # <- ADDED (including the IF)
+    if IS_CAUSAL:
+        # in causal mode, we expect that BLOCK_M_SIZE == BLOCK_N_SIZE
+        # autotune will prune shapes not matching this rule
+        block_n_end = (start_m + 1) * BLOCK_N + start_position
+    for start_n in range(0, block_n_end, BLOCK_N):
+        block_n_offs = start_n + offs_n  # <- ADDED
+        # -- compute qk ----
+        k = tl.load(k_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
+        if USE_FP8:
+            k = k.to(tl.float8e5, bitcast=True)
+            k = k.to(tl.float16)
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        qk += tl.dot(q, tl.trans(k))
+        qk = tl.where(offs_n[None, :] < N_CTX, qk, float("-inf"))  # <- ADDED
+        qk *= sm_scale
+        if IS_CAUSAL:  # <- ADDED
+            qk = tl.where(offs_m[:, None] >= (block_n_offs[None, :] + start_position), qk, float("-inf"))
+
+        # compute new m
+        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
+        # correct old l
+        l_prev *= tl.exp(m_prev - m_curr)
+        # attention weights
+        p = tl.exp(qk - m_curr[:, None])
+        l_curr = tl.sum(p, 1) + l_prev
+        # rescale operands of matmuls
+        l_rcp = 1.0 / l_curr
+        p *= l_rcp[:, None]
+        acc *= (l_prev * l_rcp)[:, None]
+        # update acc
+        p = p.to(Q.dtype.element_ty)
+        v = tl.load(v_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
+        if USE_FP8:
+            v = v.to(tl.float8e5, bitcast=True)
+            v = v.to(tl.float16)
+        acc += tl.dot(p, v)
+        # update m_i and l_i
+        l_prev = l_curr
+        m_prev = m_curr
+        # update pointers
+        k_ptrs += BLOCK_N * stride_kn
+        v_ptrs += BLOCK_N * stride_vk
+    # rematerialize offsets to save registers
+    start_m = tl.program_id(0)
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+
+    # initialize pointers to output
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    off_o = batch_id * stride_oz + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
+    out_ptrs = Out + off_o
+    tl.store(out_ptrs, acc, offs_m[:, None] < H)
+
+
+def triton_fa(q, k, v, sm_scale, is_causal, start_position):
+    assert q.dtype == torch.float16
+    assert k.dtype == v.dtype and k.dtype in [torch.float16, torch.int8]
+
+    BLOCK = 64
+    # shape constraints
+    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+    assert Lq == Lk and Lk == Lv
+    assert Lk in {16, 32, 64, 128}
+    o = torch.empty_like(q)
+    num_warps = 4 if Lk <= 64 else 8
+    batch, head_size, m_size, dhead = q.size()
+    grid = (triton.cdiv(m_size, BLOCK), head_size * batch)
+    n_size = k.size(2)
+    _fwd_kernel[grid](
+        q,
+        k,
+        v,
+        sm_scale,
+        o,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        q.stride(3),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        k.stride(3),
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        v.stride(3),
+        o.stride(0),
+        o.stride(1),
+        o.stride(2),
+        o.stride(3),
+        head_size,
+        m_size,
+        n_size,
+        start_position=start_position,
+        IS_CAUSAL=is_causal,
+        BLOCK_M=BLOCK,
+        BLOCK_N=BLOCK,
+        BLOCK_DMODEL=Lk,
+        USE_FP8=k.dtype == torch.int8,  # USE_FP8
+        num_warps=num_warps,
+        num_stages=2,
+    )
+
+    return o
+
+
+def attention_triton_wrapper(q, k, v, head_dim):
+    return triton_fa(q, k, v, sm_scale=1 / math.sqrt(head_dim), is_causal=True, start_position=0)

+ 49 - 0
src/petals/triton/rmsnorm.py

@@ -0,0 +1,49 @@
+import triton
+import triton.language as tl
+import torch
+
+@triton.jit
+def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr,
+                   stride_x_batch, stride_x_m, stride_x_k,
+                   stride_rms_w,
+                   stride_out_batch, stride_out_m, stride_out_k,
+                   N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr):
+    pid_batch = tl.program_id(0)
+    pid_m = tl.program_id(1)
+
+    offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m
+    block_N = tl.arange(0, BLOCK_N_SIZE)
+    var = tl.zeros((BLOCK_N_SIZE,), tl.float32)
+    for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
+        offs_n = block_n_start_idx + block_N
+        x_ptr_mask = offs_n < N_SIZE
+        x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0)
+        var += tl.math.pow(x.to(tl.float32), 2)
+
+    var = tl.sum(var, axis=0) / N_SIZE
+    rstd = tl.math.rsqrt(var + eps)
+
+    # multiply by weight and add bias
+    for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
+        offs_n = block_n_start_idx + block_N
+        x_ptr_mask = offs_n < N_SIZE
+        rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask)
+
+        x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32)
+        x_hat = x * rstd
+        out = x_hat * rms_w
+        out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k
+        tl.store(output_ptr + out_off, out, mask=x_ptr_mask)
+
+
+def rmsnorm_triton_wrapper(x, rms_w, eps=1e-6):
+    batch_size, seq_length, hid_dim = x.shape
+    assert rms_w.shape[-1] == hid_dim
+    out = torch.empty_like(x)
+    rmsnorm_triton[(batch_size, seq_length,)](x, rms_w, out,
+                                *x.stride(),
+                                *rms_w.stride(),
+                                *out.stride(),
+                                N_SIZE=hid_dim, eps=eps, BLOCK_N_SIZE=1024,
+                                )
+    return out

+ 81 - 0
src/petals/triton/rotary.py

@@ -0,0 +1,81 @@
+import triton
+import triton.language as tl
+import torch
+
+
+@triton.jit
+def get_freq_multi_tokens(offs_cn, starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr):
+    DIM: tl.constexpr = 128  # in model, dim = self.params.dim // self.params.n_heads
+    freqs = offs_cn % DIM
+    freqs = freqs.to(tl.float32) / DIM
+    freqs = tl.math.pow(theta, freqs)
+    freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :]
+    return tl.cos(freqs), tl.sin(freqs)
+
+
+@triton.jit
+def rbe_triton(
+    x_ptr,
+    out_ptr,
+    M,
+    K,
+    stride_x_batch,
+    stride_x_m,
+    stride_x_n,
+    stride_out_batch,
+    stride_out_m,
+    stride_out_n,
+    start_token_position,
+    THETA: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    pid_batch = tl.program_id(axis=0)
+    pid = tl.program_id(axis=1)
+    pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)
+    pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K)
+
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2  # take only even numbers
+    x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :])
+    x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
+    real = tl.load(x_ptrs, mask=x_real_mask, other=0.0)
+    x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
+    imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0)
+    tl.debug_barrier()
+    start_block = start_token_position + pid_m * BLOCK_SIZE_M
+    cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M)
+
+    out_real = real * cos - imag * sin
+    out_imag = real * sin + imag * cos
+    tl.debug_barrier()
+    out_ptrs = out_ptr + (
+        pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :]
+    )
+    out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
+    tl.store(out_ptrs, out_real, mask=out_real_mask)
+    out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
+    tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask)
+
+
+def rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor:
+    batch, M, K = x.shape
+    out = torch.empty_like(x)
+    grid = lambda META: (
+        batch,
+        triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["K"], META["BLOCK_SIZE_K"]),
+    )
+
+    rbe_triton[grid](
+        x,
+        out,
+        M,
+        K,
+        *x.stride(),
+        *out.stride(),
+        start_token_position=pos,
+        THETA=10000.0,
+        BLOCK_SIZE_M=2,
+        BLOCK_SIZE_K=1024
+    )
+    return out

+ 13 - 0
src/petals/utils/convert_block.py

@@ -50,6 +50,19 @@ def convert_block(
     if freeze:
         block.requires_grad_(False)
 
+    if hasattr(block, "self_attn") and hasattr(block.self_attn, "qkv_proj"):
+        offset = 0
+        for data in [
+            block.self_attn.q_proj.weight.data,
+            block.self_attn.k_proj.weight.data,
+            block.self_attn.v_proj.weight.data,
+        ]:
+            block.self_attn.qkv_proj.weight.data[offset : offset + data.size(0)].copy_(data)
+            offset += data.size(0)
+        del block.self_attn.q_proj
+        del block.self_attn.k_proj
+        del block.self_attn.v_proj
+
     block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
 
     if quant_type != QuantType.NONE: