Attention Mechanisms
Multi-Head Attention
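The attention mechanism computes scaled dot-product attention:
\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]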
The scaling factor \(\sqrt{d_k}\) keeps the inner products from growing large enough to destabilize gradients. If the elements of Q and K are independent with mean 0 and variance \(\sigma^2\), each entry of \(QK^T\) is a sum of \(d_k\) products and has variance \(d_k\sigma^4\), which grows linearly with \(d_k\). Dividing by \(\sqrt{d_k}\) brings the variance back to \(\sigma^4\) (1 when \(\sigma = 1\)), so the softmax does not saturate and its gradients do not vanish.
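The variance argument can be checked empirically. The following is a minimal sketch using only the Python standard library; the function name `dot_product_variance`, the sample count, and the choice \(d_k = 64\) are illustrative assumptions, not part of any library.

```python
import math
import random

def dot_product_variance(d_k, n_samples=2000, sigma=1.0, scale=False, seed=0):
    """Empirical variance of q.k for random q, k with i.i.d. N(0, sigma^2) entries."""
    rng = random.Random(seed)
    values = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, sigma) for _ in range(d_k)]
        k = [rng.gauss(0.0, sigma) for _ in range(d_k)]
        s = sum(a * b for a, b in zip(q, k))
        if scale:
            s /= math.sqrt(d_k)  # the 1/sqrt(d_k) factor from the attention formula
        values.append(s)
    mean = sum(values) / n_samples
    return sum((v - mean) ** 2 for v in values) / n_samples

raw_var = dot_product_variance(64)                 # expected to be near d_k = 64
scaled_var = dot_product_variance(64, scale=True)  # expected to be near 1
print(f"unscaled: {raw_var:.1f}, scaled: {scaled_var:.2f}")
```

With unit-variance inputs, the unscaled variance should land close to \(d_k\) and the scaled one close to 1, up to sampling noise.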
Multi-head attention runs several attention operations in parallel, each on its own learned projection of the inputs, so the model can capture different types of relationships simultaneously rather than being limited to a single attention pattern.
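As a rough illustration of the head-splitting idea, the sketch below runs attention independently on slices of the model dimension and concatenates the results. It is a simplification: the learned query, key, value, and output projections of a real multi-head layer are omitted (treated as identity), and the helper names are hypothetical.

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention on lists of row vectors."""
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, V))
                    for j in range(len(V[0]))])
    return out

def multi_head_self_attention(X, n_heads):
    """Split features into n_heads slices, attend per slice, concatenate."""
    d_model = len(X[0])
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    head_dim = d_model // n_heads
    head_outputs = []
    for h in range(n_heads):
        lo, hi = h * head_dim, (h + 1) * head_dim
        Xh = [row[lo:hi] for row in X]  # this head's view of every token
        head_outputs.append(attention(Xh, Xh, Xh))
    # Concatenate the per-head outputs back to width d_model
    return [sum((head[i] for head in head_outputs), []) for i in range(len(X))]

X = [[1.0, 0.0, 0.0, 1.0],
     [0.0, 1.0, 1.0, 0.0],
     [1.0, 1.0, 0.0, 0.0]]
out = multi_head_self_attention(X, n_heads=2)
print(len(out), len(out[0]))
```

Because each head computes its own softmax over its own slice, the heads can weight the same tokens differently, which is the property the paragraph above describes.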
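Residual connections in attention blocks typically follow either the Post-Norm pattern (\(Norm(x + f(x))\)), which normalizes after the residual addition, or the Pre-Norm pattern (\(x + f(Norm(x))\)), which normalizes the sublayer input and leaves the residual path untouched. Some implementations use learnable scaling factors (\(x + \alpha f(x)\)) that gradually increase during training.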
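The residual patterns can be compared side by side in a few lines. This is a minimal sketch on plain Python lists, taking Pre-Norm in its usual sense of normalizing the sublayer input, \(x + f(Norm(x))\). The `layer_norm` here has no learnable gain or bias, and `double` is a hypothetical stand-in for the attention sublayer \(f\).

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and (approximately) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def post_norm(x, f):
    # Norm(x + f(x)): the normalization sits on the residual path
    return layer_norm([a + b for a, b in zip(x, f(x))])

def pre_norm(x, f):
    # x + f(Norm(x)): the residual path carries x through unchanged
    return [a + b for a, b in zip(x, f(layer_norm(x)))]

def scaled_residual(x, f, alpha):
    # x + alpha * f(x): alpha = 0 reduces the block to the identity
    return [a + alpha * b for a, b in zip(x, f(x))]

def double(v):
    """Hypothetical sublayer used only for illustration."""
    return [2.0 * e for e in v]

x = [1.0, 2.0, 3.0, 4.0]
post = post_norm(x, double)
pre = pre_norm(x, double)
identity = scaled_residual(x, double, alpha=0.0)
print(post, pre, identity)
```

Note that the Post-Norm output is always renormalized, while the Pre-Norm output keeps the scale of `x`, which is one reason Pre-Norm stacks tend to be easier to train at depth.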
Causal Attention
Causal attention restricts each position to attend only to itself and earlier positions, preventing information flow from future tokens. This is implemented by masking the attention scores before the softmax, typically by setting the entries for future positions to \(-\infty\) so that they receive zero weight.
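A minimal sketch of the masking step, assuming the conventional choice that a position may attend to itself. The function name `causal_attention_weights` is illustrative; it takes a square matrix of raw scores and returns the masked softmax weights.

```python
import math

def causal_attention_weights(scores):
    """Row-wise softmax over scores with future positions masked out."""
    weights = []
    for i, row in enumerate(scores):
        # Position i may see columns 0..i; later columns get -inf
        masked = [s if j <= i else float("-inf") for j, s in enumerate(row)]
        m = max(masked)  # finite, because the diagonal entry is never masked
        exps = [math.exp(s - m) for s in masked]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# With uniform scores, each row spreads its weight evenly over the visible prefix
weights = causal_attention_weights([[0.0] * 3 for _ in range(3)])
for row in weights:
    print(row)
```

The resulting weight matrix is lower triangular, so the output at position i cannot depend on any token after i.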
Memory Optimization Techniques
Flash Attention
Flash Attention reduces the memory traffic and memory footprint of attention on long sequences by tiling the computation to fit the GPU memory hierarchy.
The algorithm divides Q, K, and V into blocks that fit in fast on-chip SRAM and processes them block by block, so the full attention matrix is never written to the slower high-bandwidth memory (HBM). A numerically safe softmax, which tracks a running maximum and rescales partial results, keeps the chunked computation stable.
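The chunked softmax can be illustrated without a GPU. The sketch below processes the keys and values for a single query one block at a time, keeping only a running maximum, a running normalizer, and an unnormalized output accumulator. It shows the numerical idea only, not the actual kernel: there is no SRAM tiling or parallelism, and the function names are hypothetical.

```python
import math
import random

def attention_reference(q, K, V):
    """Standard attention for one query, materializing all scores at once."""
    d_k = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e * v[j] for e, v in zip(exps, V)) / total
            for j in range(len(V[0]))]

def attention_blockwise(q, K, V, block_size):
    """Same result, but visiting K and V one block at a time."""
    d_k = len(q)
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(V[0])  # unnormalized output accumulator
    for start in range(0, len(K), block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum
        correction = math.exp(running_max - new_max)
        exps = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(exps)
        acc = [a * correction + sum(e * v[j] for e, v in zip(exps, V_blk))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]

rng = random.Random(0)
q = [rng.gauss(0, 1) for _ in range(8)]
K = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(10)]
V = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(10)]
full = attention_reference(q, K, V)
tiled = attention_blockwise(q, K, V, block_size=3)
print(max(abs(a - b) for a, b in zip(full, tiled)))
```

The two results agree up to floating-point rounding for any block size, which is what allows the blocks to be sized to the available fast memory.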
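import torch
from flash_attn import flash_attn_func

# flash_attn kernels require a GPU; there is no CPU fallback.
device = torch.device('cuda')
# Shape: (batch_size, seqlen, nheads, headdim)
query = torch.randn(32, 64, 8, 128, device=device, dtype=torch.bfloat16)
# Self-attention: the same tensor serves as Q, K, and V.
output = flash_attn_func(query, query, query, causal=False)
print(output.shape)  # torch.Size([32, 64, 8, 128])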
Input tensors must have shape (batch_size, seqlen, nheads, headdim) with headdim ≤ 256 and dtype float16 or bfloat16.
Multi-head Latent Attention
MLA compresses the KV-cache during inference to reduce memory usage. It projects hidden states to a lower-dimensional latent before caching:
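import torch.nn as nn

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, dim, n_heads, q_rank, kv_rank):
        super().__init__()
        self.n_heads = n_heads
        # Assumption: the per-head width is dim // n_heads
        head_dim = dim // n_heads
        # Low-rank query path: down-project, then expand to all heads
        self.q_proj_down = nn.Linear(dim, q_rank)
        self.q_proj_up = nn.Linear(q_rank, n_heads * head_dim)
        # Keys and values share one compressed latent, which is what gets cached
        self.kv_proj = nn.Linear(dim, kv_rank)
        self.kv_norm = nn.LayerNorm(kv_rank)

    def forward(self, x):
        # Implementation details
        pass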
MLA requires careful handling of positional encodings due to the compression operations.
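To get a feel for the savings, the arithmetic below compares the cache size of standard multi-head attention, which stores full keys and values for every head, with caching a single compressed latent of width `kv_rank` per token. All model dimensions are hypothetical values chosen for illustration, and the sketch ignores any extra per-token state an implementation might keep for positional encodings.

```python
def kv_cache_bytes(n_layers, seq_len, per_token_elems, bytes_per_elem=2):
    """Cache size for one sequence; 2 bytes per element matches float16/bfloat16."""
    return n_layers * seq_len * per_token_elems * bytes_per_elem

# Hypothetical model configuration, for illustration only
n_layers, n_heads, head_dim = 32, 32, 128
kv_rank = 512
seq_len = 4096

# Standard attention caches K and V for every head: 2 * n_heads * head_dim per token
standard = kv_cache_bytes(n_layers, seq_len, 2 * n_heads * head_dim)
# Latent caching stores one kv_rank-wide vector per token
latent = kv_cache_bytes(n_layers, seq_len, kv_rank)

print(f"standard: {standard // 2**20} MiB")
print(f"latent:   {latent // 2**20} MiB")
print(f"ratio:    {standard // latent}x")
```

Under these assumed dimensions the per-token cache shrinks from `2 * n_heads * head_dim` elements to `kv_rank` elements; the price is the extra up-projection work needed to recover keys and values from the latent.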
Paged Attention
Paged Attention (PagedAttention in vLLM) manages KV-cache memory by storing it in fixed-size blocks that need not be contiguous, instead of reserving one contiguous region per sequence. This addresses three memory issues:
- Reserved memory waste
- Internal memory fragmentation
- External memory fragmentation
The approach uses a block table that maps each sequence's logical blocks to the physical memory blocks holding its data, so the only wasted space is the unfilled part of a sequence's final block.
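The block-table bookkeeping can be sketched with a toy allocator. This is an illustrative model only, not vLLM's implementation: the class `PagedKVCache` and its methods are hypothetical, and it tracks block ownership without storing any actual key or value tensors.

```python
class PagedKVCache:
    """Toy block allocator: tracks which physical blocks each sequence owns."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> list of physical block ids
        self.seq_lens = {}      # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        """Reserve a slot for one new token; return (physical_block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.seq_lens.get(seq_id, 0)
        if length == len(table) * self.block_size:  # all owned blocks are full
            if not self.free_blocks:
                raise MemoryError("no free KV-cache blocks")
            table.append(self.free_blocks.pop())    # allocate on demand
        self.seq_lens[seq_id] = length + 1
        return table[length // self.block_size], length % self.block_size

    def free(self, seq_id):
        """Return all of a finished sequence's blocks to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)

    def wasted_slots(self, seq_id):
        """Unused slots, which can only be in the sequence's last block."""
        return len(self.block_tables[seq_id]) * self.block_size - self.seq_lens[seq_id]

cache = PagedKVCache(num_blocks=8, block_size=4)
for _ in range(4):
    cache.append_token("a")
for _ in range(3):
    cache.append_token("b")
for _ in range(2):
    cache.append_token("a")  # "a" grows again and gets a non-adjacent block
print(cache.block_tables, cache.wasted_slots("a"), cache.wasted_slots("b"))
```

Because blocks are allocated only when the previous one fills up, nothing is reserved ahead of time, and a sequence's blocks do not need to be adjacent, so free blocks anywhere in the pool remain usable.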
from vllm import LLM, SamplingParams

llm = LLM(model="path/to/model", dtype="float16")
outputs = llm.generate(["Sample prompt"], SamplingParams(temperature=0.8))
# Each RequestOutput holds one or more completions; print the first one's text.
print(outputs[0].outputs[0].text)