Efficient Attention Mechanisms and Memory Optimization in Deep Learning

Attention Mechanisms

Multi-Head Attention

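The attention mechanism computes:

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

where \(Q\), \(K\) and \(V\) are the query, key and value matrices and \(d_k\) is the key dimension.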
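As a minimal sketch, scaled dot-product attention, \(\text{softmax}(QK^T/\sqrt{d_k})V\), can be written for a single head in NumPy. The function names are illustrative, and the shapes assume 2-D arrays Q of shape (n, d_k), K of shape (m, d_k) and V of shape (m, d_v):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    # Q: (n, d_k), K: (m, d_k), V: (m, d_v) -> output: (n, d_v)
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n, m) scaled similarity scores
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((6, 8))
V = rng.standard_normal((6, 16))
out, w = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # (4, 16)
```

Each row of the returned weight matrix is a probability distribution over the m keys, and each output row is the corresponding weighted average of the value rows.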

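The scaling factor \(\sqrt{d_k}\) keeps the dot products from growing with the dimension. If the components of a query \(q\) and a key \(k\) are independent with mean 0 and variance \(\sigma^2\), each entry of \(QK^T\) is a sum of \(d_k\) products \(q_i k_i\), each with variance \(\sigma^4\), so the entry has variance \(d_k\sigma^4\). Dividing by \(\sqrt{d_k}\) brings the variance back to \(\sigma^4\), independent of \(d_k\). Without the scaling, large logits push the softmax toward a one-hot distribution, where its gradients become vanishingly small and training becomes unstable.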
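The variance argument is easy to check empirically. This sketch assumes \(\sigma = 1\) (independent standard-normal components), so the raw dot products should have variance close to \(d_k\) and the scaled ones variance close to 1:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 10_000
results = {}

for d_k in (16, 64, 256):
    # Query and key components drawn i.i.d. from N(0, 1), i.e. sigma = 1
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=1)       # raw dot products q . k
    scaled = dots / np.sqrt(d_k)     # dot products after scaling
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:3d}  var(q.k)={dots.var():6.1f}  var(q.k/sqrt(d_k))={scaled.var():.2f}")
```

The raw variance grows roughly linearly with \(d_k\), while the scaled variance stays near 1 for every dimension.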

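Multi-head attention runs several attention functions in parallel, each on its own learned projection of \(Q\), \(K\) and \(V\) into a smaller subspace (typically \(d_k = d_{model}/h\) for \(h\) heads), then concatenates the head outputs and applies a final output projection. This lets different heads capture different types of relationships simultaneously rather than limiting the model to a single attention pattern.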
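A minimal NumPy sketch of multi-head attention follows. It assumes random matrices W_q, W_k, W_v and W_o stand in for learned weights and that d_model is divisible by n_heads; the names are illustrative. Each head attends within its own d_model/n_heads subspace, and the head outputs are concatenated and projected back:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    # X: (seq_len, d_model); every W_*: (d_model, d_model)
    seq_len, d_model = X.shape
    d_head = d_model // n_heads

    def split_heads(M):
        # (seq_len, d_model) -> (n_heads, seq_len, d_head)
        return M.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = (split_heads(X @ W) for W in (W_q, W_k, W_v))
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (n_heads, seq_len, seq_len)
    weights = softmax(scores, axis=-1)                  # one attention pattern per head
    heads = weights @ V                                 # (n_heads, seq_len, d_head)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o, weights

rng = np.random.default_rng(0)
d_model, n_heads, seq_len = 32, 4, 5
X = rng.standard_normal((seq_len, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
out, weights = multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads)
print(out.shape, weights.shape)  # (5, 32) (4, 5, 5)
```

Slicing the projected d_model columns into n_heads groups is equivalent to giving each head its own d_model × d_head projection matrix, which is why a single matrix multiply per Q, K and V suffices.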

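Residual connections in attention blocks typically follow either the Post-Norm pattern, \(\mathrm{Norm}(x + f(x))\), used in the original Transformer, or the Pre-Norm pattern, \(x + f(\mathrm{Norm}(x))\), which normalizes the sublayer input and leaves the residual path untouched. Some implementations instead use a learnable scaling factor, \(x + \alpha f(x)\); in ReZero, for example, \(\alpha\) is initialized to 0 so that each block starts as the identity function, and its value is learned during training.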
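The three residual patterns can be written side by side. This NumPy sketch assumes a layer norm without learned gain or bias and uses a fixed random linear map as a stand-in for the sublayer \(f\); it illustrates the wiring only:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each row to zero mean and unit variance (no learned gain/bias)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm(x, f):
    # Post-Norm: normalize after adding the residual
    return layer_norm(x + f(x))

def pre_norm(x, f):
    # Pre-Norm: normalize the sublayer input; the residual path is untouched
    return x + f(layer_norm(x))

def scaled_residual(x, f, alpha):
    # Learnable-scale residual; ReZero initializes alpha to 0
    return x + alpha * f(x)

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 8))

def f(h):
    # Stand-in sublayer: a fixed random linear map
    return h @ W

x = rng.standard_normal((3, 8))
print(np.allclose(scaled_residual(x, f, alpha=0.0), x))  # True: the block starts as identity
```

In the Pre-Norm form the input \(x\) reaches the output unchanged through the residual branch, which is one reason it is often reported to train more stably in deep stacks.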

Causal Attention

Causal attention restricts each position to attend only to itself and earlier positions, so no information flows from future tokens. It is implemented by masking the attention scores before the softmax: entries for future positions (\(j > i\)) are set to \(-\infty\), so their softmax weights become exactly zero.
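A minimal NumPy sketch of causal masking for a single head, assuming square (n × n) self-attention; the function name is illustrative. The check at the end confirms that changing the last token leaves all earlier outputs unchanged:

```python
import numpy as np

def causal_attention(Q, K, V):
    # Q, K: (n, d_k), V: (n, d_v); position i may attend to positions 0..i only
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)      # True where j > i
    scores = np.where(future, -np.inf, scores)              # mask out future tokens
    scores = scores - scores.max(axis=-1, keepdims=True)    # diagonal is never masked, so max is finite
    weights = np.exp(scores)                                # exp(-inf) == 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d = 5, 4
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out, w = causal_attention(Q, K, V)

# Changing the last token must not affect earlier outputs
K2, V2 = K.copy(), V.copy()
K2[-1] += 10.0
V2[-1] += 10.0
out2, _ = causal_attention(Q, K2, V2)
print(np.allclose(out[:-1], out2[:-1]))  # True
```

The first position can only see itself, so its output is exactly the first value row.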

Memory Optimization Techniques

Flash Attention

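Flash Attention is an exact attention algorithm that reduces both runtime and memory use for long sequences by taking the GPU memory hierarchy into account. Rather than materializing the full \(N \times N\) attention matrix in high-bandwidth memory (HBM), it computes attention tile by tile in fast on-chip SRAM, so the extra memory it needs grows linearly rather than quadratically with sequence length.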

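The algorithm splits Q, K and V into blocks that fit in SRAM and iterates over them, computing partial attention results on chip and writing only the combined result back to HBM, which cuts the number of slow HBM reads and writes. Because no single block sees a full row of scores, the softmax is computed online: a running row maximum and normalizer are maintained, and the partial results are rescaled whenever a new block raises the maximum. This max-subtracted (safe) softmax keeps the exponentials from overflowing while producing the same result as standard attention.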
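The block-wise loop with an online softmax can be sketched in NumPy. This is a single-head CPU illustration of the idea under a hypothetical block size; it does not model SRAM/HBM transfers, query blocking, or the backward pass. It keeps a running row maximum `m`, a running normalizer `l` and an unnormalized output `acc`, rescaling them whenever a new key block raises the maximum, and compares the result with ordinary attention:

```python
import numpy as np

def naive_attention(Q, K, V):
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block_size=4):
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((n, 1), -np.inf)        # running max of scores per query row
    l = np.zeros((n, 1))                # running sum of exp(score - m)
    acc = np.zeros((n, V.shape[-1]))    # running unnormalized output
    for start in range(0, K.shape[0], block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        s = Q @ Kb.T * scale                                 # scores for this key block only
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)                       # rescale earlier statistics
        p = np.exp(s - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ Vb
        m = m_new
    return acc / l                                           # normalize once at the end

rng = np.random.default_rng(0)
Q = rng.standard_normal((6, 8)) * 10.0   # large scores stress numerical stability
K = rng.standard_normal((10, 8))
V = rng.standard_normal((10, 8))
print(np.allclose(tiled_attention(Q, K, V, block_size=4), naive_attention(Q, K, V)))  # True
```

With 10 keys and a block size of 4, the last block holds only 2 keys; the running statistics handle uneven blocks without special cases.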

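import torch
from flash_attn import flash_attn_func

# FlashAttention kernels run only on CUDA GPUs; there is no CPU fallback
assert torch.cuda.is_available(), "flash_attn requires a CUDA GPU"
device = torch.device("cuda")

# Layout: (batch_size, seqlen, nheads, headdim); bfloat16 needs an Ampere or newer GPU
q = torch.randn(32, 64, 8, 128, device=device, dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
output = flash_attn_func(q, k, v, causal=False)
print(output.shape)  # torch.Size([32, 64, 8, 128])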

Input tensors must have shape (batch_size, seqlen, nheads, headdim) with headdim ≤ 256 and dtype float16 or bfloat16.

Multi-head Latent Attention

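Multi-head Latent Attention (MLA), introduced in DeepSeek-V2, shrinks the KV cache used during inference. Instead of caching full per-head keys and values, it projects each token's hidden state down to a low-rank latent vector, caches only that latent, and reconstructs keys and values from it with up-projections: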

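import torch.nn as nn
import torch.nn.functional as F

class MultiHeadLatentAttention(nn.Module):
    # Sketch: head_dim, kv_proj_up and out_proj are illustrative; decoupled RoPE is omitted
    def __init__(self, dim, n_heads, head_dim, q_rank, kv_rank):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, head_dim
        self.q_proj_down = nn.Linear(dim, q_rank)                      # low-rank query path
        self.q_proj_up = nn.Linear(q_rank, n_heads * head_dim)
        self.kv_proj = nn.Linear(dim, kv_rank)                         # latent c_KV: what gets cached
        self.kv_norm = nn.LayerNorm(kv_rank)
        self.kv_proj_up = nn.Linear(kv_rank, 2 * n_heads * head_dim)   # rebuild K and V
        self.out_proj = nn.Linear(n_heads * head_dim, dim)

    def forward(self, x):  # x: (batch, seqlen, dim)
        b, t, _ = x.shape
        q = self.q_proj_up(self.q_proj_down(x))
        c_kv = self.kv_norm(self.kv_proj(x))  # at inference, only this is stored per token
        k, v = self.kv_proj_up(c_kv).chunk(2, dim=-1)
        # (batch, seqlen, heads * head_dim) -> (batch, heads, seqlen, head_dim)
        q, k, v = (z.view(b, t, self.n_heads, self.head_dim).transpose(1, 2) for z in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(out.transpose(1, 2).reshape(b, t, -1))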

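MLA requires careful handling of positional encodings because rotary position embeddings (RoPE) do not commute with the low-rank compression: applying RoPE to keys reconstructed from the latent would prevent the key up-projection from being absorbed into the query projection at inference time. DeepSeek-V2 therefore uses a decoupled RoPE, carrying positional information in a small set of extra query and key dimensions, with the rotary key shared across heads and cached alongside the latent.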
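To see why caching the latent matters, compare per-token cache sizes. The helper functions and the configuration below (32 layers, 32 heads of dimension 128, a 512-dimensional latent plus 64 rotary dimensions, 2-byte fp16 elements) are hypothetical, chosen only to illustrate the arithmetic:

```python
def mha_cache_bytes_per_token(n_layers, n_heads, head_dim, bytes_per_elem=2):
    # Standard multi-head attention caches one key and one value vector per head per layer
    return 2 * n_layers * n_heads * head_dim * bytes_per_elem

def mla_cache_bytes_per_token(n_layers, kv_rank, rope_dim=0, bytes_per_elem=2):
    # MLA caches only the kv_rank latent (plus any decoupled rotary key dims) per layer
    return n_layers * (kv_rank + rope_dim) * bytes_per_elem

# Hypothetical configuration, for illustration only
mha = mha_cache_bytes_per_token(n_layers=32, n_heads=32, head_dim=128)
mla = mla_cache_bytes_per_token(n_layers=32, kv_rank=512, rope_dim=64)
print(mha, mla, round(mha / mla, 1))  # 524288 36864 14.2
```

Under these assumptions each cached token shrinks from 512 KiB to 36 KiB, a roughly 14× reduction; the actual ratio depends entirely on the chosen ranks and head configuration.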

PagedAttention

PagedAttention, introduced with the vLLM serving engine, stores each sequence's KV cache in fixed-size blocks that need not be contiguous in memory, much as an operating system pages virtual memory. This addresses three sources of waste in contiguous allocation:

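  • Reserved memory: slots set aside for tokens a request has not generated yet
  • Internal fragmentation: space over-allocated for a maximum length the request never reaches
  • External fragmentation: unusable gaps left between contiguous allocations of different sizes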

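A per-sequence block table maps logical blocks (consecutive chunks of the sequence) to physical blocks in GPU memory, and a new physical block is allocated only when the current one fills up. Waste is therefore limited to the unfilled tail of each sequence's last block.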
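The block-table bookkeeping can be sketched in a few lines of Python. This is a toy allocator, not vLLM's implementation; the class name, the block size of 4 tokens and the free-list policy are assumptions for illustration. Each sequence owns a list of physical block IDs, a block is taken from a shared free list only when the previous one is full, and so at most block_size - 1 slots per sequence sit unused:

```python
class BlockTableAllocator:
    """Toy paged KV-cache bookkeeping (hypothetical, for illustration)."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # physical block IDs not in use
        self.block_tables = {}                      # seq_id -> list of physical block IDs
        self.lengths = {}                           # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:                # last block is full (or none exists yet)
            if not self.free_blocks:
                raise MemoryError("no free KV-cache blocks")
            table.append(self.free_blocks.pop(0))   # blocks need not be contiguous
        self.lengths[seq_id] = n + 1
        # Return the (physical block, offset) slot where this token's K/V would be written
        return table[-1], n % self.block_size

    def free(self, seq_id):
        # Return a finished sequence's blocks to the free list
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

    def wasted_slots(self, seq_id):
        # Only the unfilled tail of the last block is unused
        return len(self.block_tables[seq_id]) * self.block_size - self.lengths[seq_id]


alloc = BlockTableAllocator(num_blocks=8, block_size=4)
for i in range(6):                 # interleave two sequences' decoding steps
    alloc.append_token("a")
    if i < 3:
        alloc.append_token("b")
print(alloc.block_tables)                               # {'a': [0, 2], 'b': [1]}
print(alloc.wasted_slots("a"), alloc.wasted_slots("b"))  # 2 1
```

Because the two sequences grow in turn, sequence "a" ends up in physical blocks 0 and 2; the block table hides that non-contiguity from the attention kernel, which looks up each token's slot through the table.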

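from vllm import LLM, SamplingParams

# The model path is a placeholder; dtype can be given as a string
llm = LLM(model="path/to/model", dtype="float16")
outputs = llm.generate(["Sample prompt"], SamplingParams(temperature=0.8))
# Each RequestOutput holds a list of CompletionOutput objects
print(outputs[0].outputs[0].text)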

Tags: attention-mechanism memory-optimization Transformer flash-attention kv-cache

Posted on Thu, 18 Jun 2026 17:39:50 +0000 by bschaeffer