Implementing and Optimizing PagedAttention Kernels in vLLM

PagedAttention Memory Layout and Block Mapping

PagedAttention replaces the traditional contiguous key-value (KV) cache allocation with a virtual-to-physical block mapping. As with operating-system memory paging, a sequence's KV cache can occupy non-contiguous, fixed-size blocks of GPU memory. Memory is therefore allocated on demand, and the only waste is the unused tail of each sequence's last block, rather than a reservation sized for the maximum possible length. Each request maintains a block table that maps its logical blocks to physical cache blocks, which enables dynamic growth and efficient batching across heterogeneous sequence lengths.
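The block-table indirection can be sketched as a toy model in pure Python. This is not vLLM's allocator. BlockAllocator, SequenceBlockTable, and the LIFO free-list policy are hypothetical. The flat slot index block_id * block_size + offset is also an assumption of this sketch.

```python
from typing import List


class BlockAllocator:
    """Hypothetical free-list allocator over a fixed pool of physical KV-cache blocks."""

    def __init__(self, num_blocks: int) -> None:
        self.free_blocks: List[int] = list(range(num_blocks))

    def allocate(self) -> int:
        if not self.free_blocks:
            raise MemoryError("KV cache is out of blocks")
        # pop() takes from the end, so the highest-numbered free block is handed out first
        return self.free_blocks.pop()

    def free(self, block_id: int) -> None:
        self.free_blocks.append(block_id)


class SequenceBlockTable:
    """Maps one sequence's logical blocks to physical blocks, growing one block at a time."""

    def __init__(self, allocator: BlockAllocator, block_size: int) -> None:
        self.allocator = allocator
        self.block_size = block_size
        self.blocks: List[int] = []  # blocks[i] = physical id of logical block i
        self.num_tokens = 0

    def append_token(self) -> int:
        # A new physical block is needed only when the current last block is full
        if self.num_tokens % self.block_size == 0:
            self.blocks.append(self.allocator.allocate())
        slot = self.slot_of(self.num_tokens)
        self.num_tokens += 1
        return slot

    def slot_of(self, position: int) -> int:
        # Translate a token position into a flat physical slot index
        logical_block, offset = divmod(position, self.block_size)
        return self.blocks[logical_block] * self.block_size + offset

    def release(self) -> None:
        # Return every block to the pool when the request finishes
        for block_id in self.blocks:
            self.allocator.free(block_id)
        self.blocks.clear()
        self.num_tokens = 0


if __name__ == "__main__":
    allocator = BlockAllocator(num_blocks=8)
    seq_a = SequenceBlockTable(allocator, block_size=4)
    seq_b = SequenceBlockTable(allocator, block_size=4)
    for _ in range(6):
        seq_a.append_token()
    seq_b.append_token()
    print(seq_a.blocks, seq_b.blocks)  # [7, 6] [5]
```

Here six tokens of seq_a occupy two blocks and leave two unused slots. At most block_size - 1 slots are wasted per sequence.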

vLLM Kernel Interfaces: V1 vs V2 Dispatch

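The inference engine exposes the PagedAttention CUDA kernels as custom operations in vllm._custom_ops. The first variant (V1) launches one thread block per sequence and head. That block computes every attention logit, applies the softmax normalization, and accumulates the weighted values within a single kernel. Because all of a sequence's logits are held in shared memory, V1's shared-memory use grows with sequence length. The second variant (V2) splits each sequence into fixed-size partitions that are processed by separate thread blocks. This bounds shared-memory use and adds parallelism for long sequences. Intermediate buffers store each partition's partial output, exponential sum, and maximum logit, and a final reduction kernel rescales and combines them into the exact softmax result.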
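The final reduction rests on a rescaling identity. Suppose partition p reports its maximum logit m_p, its exponential sum s_p = sum of exp(x_i - m_p), and its locally normalized output o_p. With M = max over p of m_p, the full output is sum of exp(m_p - M) * s_p * o_p divided by sum of exp(m_p - M) * s_p. The pure-Python sketch below checks this identity for a single query head. The function names are illustrative, and this is not vLLM's CUDA reduction kernel.

```python
import math
from typing import List, Tuple

Vector = List[float]


def partition_stats(logits: Vector, values: List[Vector]) -> Tuple[float, float, Vector]:
    """Return (max logit, exp-sum, locally normalized output) for one partition."""
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    s = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / s for d in range(dim)]
    return m, s, out


def merge_partitions(stats: List[Tuple[float, float, Vector]]) -> Vector:
    """Combine partitions by rescaling each exp-sum to the global max logit."""
    global_max = max(m for m, _, _ in stats)
    rescaled = [s * math.exp(m - global_max) for m, s, _ in stats]
    total = sum(rescaled)
    dim = len(stats[0][2])
    return [
        sum(r * out[d] for r, (_, _, out) in zip(rescaled, stats)) / total
        for d in range(dim)
    ]


def reference_attention(logits: Vector, values: List[Vector]) -> Vector:
    """Softmax-weighted sum over the whole sequence as a single partition."""
    return partition_stats(logits, values)[2]


def partitioned_attention(logits: Vector, values: List[Vector], partition_size: int) -> Vector:
    """Split the sequence into fixed-size partitions, then reduce them."""
    stats = [
        partition_stats(logits[i:i + partition_size], values[i:i + partition_size])
        for i in range(0, len(logits), partition_size)
    ]
    return merge_partitions(stats)


if __name__ == "__main__":
    logits = [0.3, 2.1, -1.0, 4.5, 0.0, 1.7, -2.2]
    values = [[float(i), float(-i)] for i in range(len(logits))]
    print(reference_attention(logits, values))
    print(partitioned_attention(logits, values, partition_size=3))
```

Both calls print the same vector up to floating-point rounding. This is why V2 can split work across thread blocks without changing the result.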

import torch
from typing import Optional, Tuple

# PagedAttention CUDA ops registered by vLLM. The positional argument lists below
# follow releases whose final argument is a single KV-cache scale (kv_scale); later
# releases split it into k_scale and v_scale, so match your installed version.
from vllm._custom_ops import paged_attention_v1, paged_attention_v2
# vLLM test utility that allocates randomly initialized KV caches in the kernels'
# expected layout (its module path has moved between vLLM versions).
from vllm.utils import create_kv_caches_with_random

# Configuration constants
CACHE_SLOT_COUNT = 4096     # physical KV-cache blocks in the pool
MAX_SEQ_LEN = 2048          # longest context length (tokens) in the synthetic batch
SEQUENCE_SLICE_WIDTH = 512  # V2 partition size; must match the value compiled into the kernel


def execute_attention_pass(
    backend_version: str,          # "variant_one" (V1) or "variant_two" (V2)
    batch_count: int,
    head_config: Tuple[int, int],  # (query heads, KV heads)
    d_model: int,                  # per-head dimension
    enable_alibi: bool,
    slot_size: int,                # tokens per KV-cache block
    compute_dtype: torch.dtype,
    cache_backend: str,            # KV-cache dtype string, e.g. "auto"
    seed_val: int,
    target_device: str,
) -> torch.Tensor:
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    torch.set_default_device(target_device)

    query_scaling_factor = d_model ** -0.5
    q_heads, k_heads = head_config
    assert q_heads % k_heads == 0, "Query heads must be a multiple of KV heads"

    # Initialize query tensor with scaled uniform distribution
    queries = torch.empty(batch_count, q_heads, d_model, dtype=compute_dtype, device=target_device)
    queries.uniform_(-query_scaling_factor, query_scaling_factor)

    # Generate variable sequence lengths; the last sequence uses the maximum length
    raw_lengths = [torch.randint(1, MAX_SEQ_LEN + 1, (1,)).item() for _ in range(batch_count)]
    raw_lengths[-1] = MAX_SEQ_LEN
    seq_lens_tensor = torch.tensor(raw_lengths, dtype=torch.int32, device=target_device)
    max_seq_len = int(seq_lens_tensor.max().item())

    # Construct block tables: row i maps each logical block of sequence i to a physical block
    max_slots_per_seq = (max_seq_len + slot_size - 1) // slot_size
    block_map = torch.randint(
        0, CACHE_SLOT_COUNT, (batch_count, max_slots_per_seq),
        dtype=torch.int32, device=target_device,
    )

    # Allocate a single layer of KV storage
    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks=CACHE_SLOT_COUNT,
        block_size=slot_size,
        num_layers=1,
        num_heads=k_heads,
        head_size=d_model,
        cache_dtype=cache_backend,
        model_dtype=compute_dtype,
        seed=seed_val,
        device=target_device,
    )
    k_cache, v_cache = key_caches[0], value_caches[0]

    # Prepare optional ALiBi slopes (one per query head)
    alibi_coeffs: Optional[torch.Tensor] = None
    if enable_alibi:
        alibi_coeffs = torch.randn(q_heads, dtype=torch.float32, device=target_device)

    kv_scale = 1.0  # identity scaling for a non-quantized cache

    # Execute selected kernel variant
    result = torch.empty_like(queries)
    if backend_version == "variant_one":
        paged_attention_v1(
            result, queries, k_cache, v_cache, k_heads,
            query_scaling_factor, block_map, seq_lens_tensor,
            slot_size, max_seq_len, alibi_coeffs,
            cache_backend, kv_scale,
        )
    elif backend_version == "variant_two":
        assert SEQUENCE_SLICE_WIDTH % slot_size == 0, "Partition size must be a multiple of the block size"
        partitions = (max_seq_len + SEQUENCE_SLICE_WIDTH - 1) // SEQUENCE_SLICE_WIDTH

        # Per-partition partial outputs, exponential sums, and maximum logits
        tmp_buf = torch.empty(batch_count, q_heads, partitions, d_model, dtype=result.dtype, device=target_device)
        exp_sums = torch.empty(batch_count, q_heads, partitions, dtype=torch.float32, device=target_device)
        peak_logits = torch.empty_like(exp_sums)

        paged_attention_v2(
            result, exp_sums, peak_logits, tmp_buf, queries,
            k_cache, v_cache, k_heads, query_scaling_factor,
            block_map, seq_lens_tensor, slot_size,
            max_seq_len, alibi_coeffs,
            cache_backend, kv_scale,
        )
    else:
        raise ValueError(f"Unsupported dispatch mode: {backend_version}")

    return result


# Example execution
if __name__ == "__main__":
    final_outputs = execute_attention_pass(
        backend_version="variant_two",
        batch_count=8,
        head_config=(32, 8),
        d_model=128,
        enable_alibi=False,
        slot_size=16,
        compute_dtype=torch.bfloat16,
        cache_backend="auto",
        seed_val=42,
        target_device="cuda:0",
    )
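The example hard-codes the variant, but a caller would normally choose one per batch. The sketch below mirrors the heuristic in vLLM's paged_attn.py as we understand it, though the thresholds (512 tokens per partition, 512 sequence-head pairs, 8192 tokens) should be treated as assumptions that may differ in your version. V1 is used when a single partition suffices or when there are already enough (sequence, head) pairs to fill the GPU. V2 is used beyond 8192 tokens to avoid exhausting shared memory. choose_paged_attention_variant is a hypothetical helper name.

```python
PARTITION_SIZE = 512  # assumed to match the partition size compiled into the V2 kernel


def choose_paged_attention_variant(num_seqs: int, num_heads: int, max_seq_len: int) -> str:
    """Return "variant_one" (V1) or "variant_two" (V2) for a decode batch."""
    max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
    # V1 holds all of a sequence's logits in shared memory, so very long contexts go to V2
    fits_v1 = max_seq_len <= 8192
    # With one partition, V2's extra reduction pass buys nothing; with many
    # (sequence, head) pairs, one thread block per pair already fills the GPU
    v1_is_enough = max_num_partitions == 1 or num_seqs * num_heads > 512
    return "variant_one" if fits_v1 and v1_is_enough else "variant_two"


if __name__ == "__main__":
    print(choose_paged_attention_variant(num_seqs=8, num_heads=32, max_seq_len=256))     # variant_one
    print(choose_paged_attention_variant(num_seqs=2, num_heads=32, max_seq_len=4096))    # variant_two
    print(choose_paged_attention_variant(num_seqs=64, num_heads=32, max_seq_len=16384))  # variant_two
```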

FlashAttention Integration with Virtualized Caches

When attention becomes the primary bottleneck, swapping the PagedAttention kernels for FlashAttention can yield substantial throughput gains. FlashAttention fuses the QK^T product, the softmax, and the value accumulation into one tiled kernel. Intermediate scores stay in registers and shared memory instead of being written to GPU global memory as a full attention matrix. Its flash_attn_with_kvcache entry point can read keys and values directly from a block-addressed (paged) cache through a block table. Alongside the queries, the kernel needs only the cache tensors, the per-sequence cached lengths, and that block table.
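To make the tiling concrete, the pure-Python sketch below computes single-query (decode) attention over a paged KV cache. Each physical block is treated as one tile and folded into a running maximum, running denominator, and running output, which is the online-softmax recurrence that FlashAttention builds on. It is a readability-oriented reference, not FlashAttention's kernel. The function name paged_decode_attention, the dictionary-backed cache, and the single-head layout are illustrative assumptions.

```python
import math
from typing import Dict, List

Vector = List[float]


def paged_decode_attention(
    query: Vector,
    key_blocks: Dict[int, List[Vector]],    # physical block id -> keys stored in that block
    value_blocks: Dict[int, List[Vector]],  # physical block id -> values stored in that block
    block_table: List[int],                 # logical block i -> physical block id
    seq_len: int,
    block_size: int,
) -> Vector:
    """Single-head decode attention that streams over paged KV blocks with an online softmax."""
    scale = 1.0 / math.sqrt(len(query))
    running_max = -math.inf  # largest logit seen so far
    running_sum = 0.0        # softmax denominator relative to running_max
    acc = [0.0] * len(value_blocks[block_table[0]][0])  # unnormalized output

    for logical_block, physical_block in enumerate(block_table):
        start = logical_block * block_size
        if start >= seq_len:
            break  # remaining table entries are unused
        count = min(block_size, seq_len - start)  # the last block may be partially filled
        keys = key_blocks[physical_block][:count]
        vals = value_blocks[physical_block][:count]

        # Scores for this tile only; the full score vector is never stored
        scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum
        correction = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(weights)
        acc = [
            a * correction + sum(w * v[d] for w, v in zip(weights, vals))
            for d, a in enumerate(acc)
        ]
        running_max = new_max

    # Normalize once at the end
    return [a / running_sum for a in acc]
```

Because each block is folded in as soon as it is read, the working set per step is one tile, which is the property that lets a fused kernel keep it on chip.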

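import torch
# vLLM's FlashAttention fork. Newer vLLM releases vendor it as vllm.vllm_flash_attn;
# upstream flash-attn provides the same entry point as flash_attn.flash_attn_with_kvcache.
from vllm_flash_attn import flash_attn_with_kvcache

# FlashAttention configuration
TARGET_DEVICE = "cuda"
D_MODEL = 128      # per-head dimension
Q_HEADS = 32
KV_HEADS = 8       # grouped-query attention: 4 query heads share each KV head
SLICE_WIDTH = 16   # tokens per paged KV block (check the page-size limits of your build)
MAX_LEN = 2048     # longest cached context
SEED_VALUE = 101

torch.set_default_device(TARGET_DEVICE)
torch.manual_seed(SEED_VALUE)

# Simulate variable sequence traces (MAX_LEN must be defined before it is used here)
trace_lengths = [1250, 47, 892, 310, MAX_LEN]
NUM_REQUESTS = len(trace_lengths)
# Per-sequence cached lengths (not cumulative offsets), int32 as the kernel expects
cache_seqlens = torch.tensor(trace_lengths, dtype=torch.int32)

# Paged KV storage: (num_blocks, block_size, kv_heads, head_dim)
blocks_per_seq = (MAX_LEN + SLICE_WIDTH - 1) // SLICE_WIDTH
num_blocks = NUM_REQUESTS * blocks_per_seq
key_cache = torch.randn(num_blocks, SLICE_WIDTH, KV_HEADS, D_MODEL, dtype=torch.float16)
value_cache = torch.randn_like(key_cache)

# Block table: row i lists the physical blocks backing request i, shuffled to
# mimic non-contiguous allocation
block_table = torch.randperm(num_blocks).to(torch.int32).reshape(NUM_REQUESTS, blocks_per_seq)

# One new query token per request: (batch, seqlen_q, q_heads, head_dim)
query_tensor = torch.randn(NUM_REQUESTS, 1, Q_HEADS, D_MODEL, dtype=torch.float16)

# Compute scaling factor
softmax_scale = D_MODEL ** -0.5

# Execute FlashAttention against the paged KV cache
output_tensor = flash_attn_with_kvcache(
    q=query_tensor,
    k_cache=key_cache,
    v_cache=value_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    softmax_scale=softmax_scale,
    causal=True,
    window_size=(-1, -1),  # no sliding-window limit
    softcap=0.0,           # logit soft-capping disabled
    alibi_slopes=None,
)

print(output_tensor.shape)  # (NUM_REQUESTS, 1, Q_HEADS, D_MODEL)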

Hardware Profiling and Register Optimization

Identifying performance ceilings requires analyzing warp occupancy, shared-memory bank conflicts, and instruction throughput. Nsight Compute collects per-kernel hardware metrics, including tensor-core pipe utilization. Its roofline analysis shows whether a kernel is bound by arithmetic throughput or by memory bandwidth and latency. Within the vLLM ecosystem, kernel launches can also be instrumented from Python, for example with the PyTorch profiler, to attribute time to individual attention calls. By inspecting the registers allocated per thread and tracking L2 cache hit rates, engineers can tune tile dimensions or adjust block sizes to maximize occupancy and pipeline utilization. The goal is to do this without exceeding hardware limits such as the 255-registers-per-thread cap or the per-SM shared-memory capacity.
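The register and shared-memory trade-off can be estimated before profiling. The sketch below is a simplified theoretical-occupancy calculation in the spirit of NVIDIA's occupancy calculator. The default limits are assumptions modeled on an A100-class (sm_80) GPU: 65,536 registers, 64 resident warps, 32 resident blocks, and about 164 KB of shared memory per SM, with registers allocated per warp in units of 256. It ignores per-block shared-memory reservations and finer architectural details, so confirm real numbers with Nsight Compute's occupancy section.

```python
import math


def theoretical_occupancy(
    threads_per_block: int,
    regs_per_thread: int,
    smem_per_block: int,            # bytes of shared memory per block
    regs_per_sm: int = 65536,       # assumed sm_80-class limits from here down
    max_warps_per_sm: int = 64,
    max_blocks_per_sm: int = 32,
    smem_per_sm: int = 164 * 1024,
    reg_alloc_unit: int = 256,
    warp_size: int = 32,
) -> float:
    """Fraction of an SM's warp slots that can be resident at once (simplified model)."""
    warps_per_block = math.ceil(threads_per_block / warp_size)
    # Registers are allocated per warp, rounded up to the allocation unit
    regs_per_warp = math.ceil(regs_per_thread * warp_size / reg_alloc_unit) * reg_alloc_unit
    regs_per_block = regs_per_warp * warps_per_block
    limit_regs = regs_per_sm // regs_per_block if regs_per_block else max_blocks_per_sm
    limit_warps = max_warps_per_sm // warps_per_block
    limit_smem = smem_per_sm // smem_per_block if smem_per_block else max_blocks_per_sm
    # The tightest of the per-SM limits decides how many blocks fit
    resident_blocks = min(limit_regs, limit_warps, limit_smem, max_blocks_per_sm)
    return resident_blocks * warps_per_block / max_warps_per_sm


if __name__ == "__main__":
    print(theoretical_occupancy(128, 128, 0))  # 0.25
    print(theoretical_occupancy(128, 64, 0))   # 0.5
```

In this model, halving registers per thread from 128 to 64 doubles the resident warps. That is the kind of trade-off that __launch_bounds__ or -maxrregcount exposes, at the risk of register spills.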

Tags: paged-attention flash-attention vLLM cuda-kernels llm-inference

Posted on Wed, 20 May 2026 06:06:02 +0000 by LarryK