Transformer Model Architecture and Computational Analysis

Model Structure

The original Transformer combines token embeddings and positional encodings with an encoder stack and a decoder stack.

  • Encoder: Self-attention layer with skip connections and layer normalization, followed by a feed-forward network (FFN) with skip connections and layer normalization.
  • Decoder: Masked (causal) self-attention layer with skip connections and layer normalization, cross-attention layer over the encoder output with skip connections and layer normalization, and a feed-forward network with skip connections and layer normalization.

Computational Complexity and Parameters

Before training a large model, scaling laws are used to estimate data requirements, computational power (FLOPs), and training time.

In deep learning, the forward pass is dominated by matrix multiplications. Each weight takes part in one multiply-accumulate per input token: one multiplication and one addition, counted as two floating-point operations (FLOPs). Equivalently, multiplying an [m, k] matrix by a [k, n] matrix costs 2 × m × k × n FLOPs.

The following conventions are adopted:

  • L: Number of Transformer layers.
  • H: Model hidden dimension (d_model).
  • h: Number of attention heads; H' = H / h is the per-head dimension.
  • B: Batch size.
  • S: Sequence length (e.g., 2K for GPT-3 and the original LLaMA, 4K for LLaMA 2).
  • V: Vocabulary size.

Inference Memory

Memory consumption comes primarily from model parameters, the KV cache, and intermediate activations. At inference time activations usually account for the smallest share, since they need not be kept for a backward pass; they scale roughly as batch size × sequence length × hidden size (B × S × H).
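The KV cache size follows from its shapes: every layer stores one key and one value vector per cached token. The sketch below uses an assumed LLaMA-2-7B-like configuration (32 layers, 32 heads of dimension 128, BF16) purely for illustration; the `n_kv_heads` parameter anticipates the MQA/GQA variants discussed later.

```python
def kv_cache_bytes(n_layers, batch, seq_len, n_kv_heads, head_dim, bytes_per_value=2):
    """Approximate KV-cache size: per layer, K and V of shape [batch, n_kv_heads, seq_len, head_dim]."""
    return 2 * n_layers * batch * n_kv_heads * seq_len * head_dim * bytes_per_value


# Assumed, illustrative shapes (not taken from the text above)
mha = kv_cache_bytes(n_layers=32, batch=1, seq_len=4096, n_kv_heads=32, head_dim=128)
gqa = kv_cache_bytes(n_layers=32, batch=1, seq_len=4096, n_kv_heads=8, head_dim=128)
print(mha / 2**30, gqa / 2**30)  # 2.0 0.5  (GiB)
```

The cache grows linearly in both batch size and sequence length, which is why it can rival the weights in memory for long-context or large-batch serving.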

Attention Mechanism

Given input embeddings of shape [B, S, H]:

  1. Project Q, K, V via three dense layers: [B, S, H] × [H, H] = [B, S, H], then split each into h heads: [B, h, S, H']. Computational cost: 3 × 2 × B × S × H² = 6 × B × S × H² FLOPs.
  2. Compute attention scores: softmax(Q × Kᵀ / √H'). Shape: [B, h, S, H'] × [B, h, H', S] = [B, h, S, S]. Computational cost: 2 × B × h × S² × H' = 2 × B × H × S² FLOPs (the scaling and softmax add only O(B × h × S²) operations and are neglected).
  3. Multiply with V: [B, h, S, S] × [B, h, S, H'] = [B, h, S, H']. Computational cost: 2 × B × H × S² FLOPs.
  4. Concatenate the heads back to [B, S, H] and apply the output dense layer: [B, S, H] × [H, H] = [B, S, H]. Computational cost: 2 × B × S × H² FLOPs.

Total FLOPs for Attention: 8 × B × S × H² + 4 × B × H × S². Every matrix-multiply cost above includes a factor of 2 because each multiply-accumulate is one multiplication plus one addition.
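A quick way to check the attention total is to tally the four matrix multiplications separately and compare against the closed form. This is a minimal sketch that counts 2 FLOPs per multiply-accumulate and assumes H is divisible by h; the sizes are illustrative.

```python
def attention_flops(B, S, H, h):
    """FLOPs of one multi-head attention layer (matmuls only, 2 FLOPs per multiply-accumulate)."""
    head_dim = H // h
    qkv_proj = 3 * 2 * B * S * H * H            # step 1: three [H, H] projections
    scores = 2 * B * h * S * S * head_dim       # step 2: Q x K^T for every head
    weighted_v = 2 * B * h * S * S * head_dim   # step 3: scores x V for every head
    out_proj = 2 * B * S * H * H                # step 4: output projection
    return qkv_proj + scores + weighted_v + out_proj


B, S, H, h = 2, 2048, 4096, 32  # illustrative sizes
closed_form = 8 * B * S * H**2 + 4 * B * H * S**2
print(attention_flops(B, S, H, h) == closed_form)  # True
```

The ratio of the quadratic term to the projection term is 4BHS² / 8BSH² = S / (2H), so with S = 2048 and H = 4096 the S² part is a quarter of the projection cost; it only dominates once S exceeds 2H.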

Feed-Forward Network (FFN)

Given input of shape [B, S, H]:

  1. First linear layer: [B, S, H] × [H, 4H] = [B, S, 4H]. Computational cost: 2 × 4 × B × S × H² = 8 × B × S × H² FLOPs.
  2. Second linear layer: [B, S, 4H] × [4H, H] = [B, S, H]. Computational cost: 8 × B × S × H² FLOPs.

Total FLOPs for FFN: 16 × B × S × H².

Total Computational Cost

Forward Pass per Layer: 24 × B × S × H² + 4 × B × H × S² FLOPs.

L Layers: L × (24 × B × S × H² + 4 × B × H × S²).

Output (logits) layer: projecting [B, S, H] onto the vocabulary with an [H, V] matrix costs 2 × B × S × H × V FLOPs.
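Summing the per-layer cost and the logits layer, then comparing with a rough parameter count, shows where the "2 FLOPs per parameter per token" rule comes from. In this sketch the parameter count is an assumption: 4H² for the attention projections plus 8H² for the FFN per layer, plus an [H, V] output matrix, with biases, LayerNorm weights and a separate input embedding table ignored.

```python
def forward_flops(L, B, S, H, V):
    """Forward FLOPs: L layers of attention + FFN, plus the final [H, V] logits projection."""
    per_layer = 24 * B * S * H**2 + 4 * B * H * S**2
    logits = 2 * B * S * H * V
    return L * per_layer + logits


def approx_params(L, H, V):
    """12 H^2 weights per layer (Q, K, V, O and H -> 4H -> H) plus the [H, V] output matrix."""
    return 12 * L * H**2 + V * H


L, B, S, H, V = 32, 1, 2048, 4096, 32000  # illustrative shapes
dense_part = forward_flops(L, B, S, H, V) - L * 4 * B * H * S**2  # drop the S^2 attention terms
print(dense_part == 2 * approx_params(L, H, V) * B * S)  # True: 2 FLOPs per parameter per token
```

Once the S² attention-score terms are excluded, the identity is exact under these assumptions, which is why the forward pass is commonly estimated as 2 × parameters × tokens.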

Backpropagation computes gradients with respect to both activations and weights, which costs roughly twice the forward pass. A training step therefore costs about three times the forward pass: approximately 6 FLOPs per parameter per training token (2 forward + 4 backward).

Example: LLaMA 65B Model

Model Parameters: 65 × 10⁹. Training Tokens: 1.4 × 10¹².

Required FLOPs ≈ 6 × (Parameters × Tokens).

Practical FLOPs per second = Number of GPUs × single-GPU peak FLOPS × GPU utilization. For instance: 2048 GPUs × 312 × 10¹² FLOPS (A100, BF16 dense peak) × 0.45 utilization.

Training Time = Required FLOPs / Practical FLOPs.
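Plugging the LLaMA 65B numbers into these formulas gives the estimate below. It is back-of-the-envelope arithmetic only; the 0.45 utilization is the assumed value from the text.

```python
SECONDS_PER_DAY = 86_400


def training_days(n_params, n_tokens, n_gpus, peak_flops_per_gpu, utilization):
    required = 6 * n_params * n_tokens                    # ~6 FLOPs per parameter per token
    achieved = n_gpus * peak_flops_per_gpu * utilization  # sustained FLOP/s across the cluster
    return required / achieved / SECONDS_PER_DAY


days = training_days(65e9, 1.4e12, n_gpus=2048, peak_flops_per_gpu=312e12, utilization=0.45)
print(f"{days:.1f} days")  # 22.0 days
```

The required compute is about 5.5 × 10²³ FLOPs and the cluster sustains about 2.9 × 10¹⁷ FLOP/s, giving roughly 1.9 × 10⁶ seconds.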

Memory Footprint

Models are typically loaded in 16-bit precision (e.g., BF16), where each parameter occupies 2 bytes.

For a 7B parameter model, static weight memory is approximately 14 GB.

During mixed-precision training with AdamW, memory per parameter typically includes:

  • 2 bytes for model weights (16-bit).
  • 2 bytes for weight gradients (16-bit).
  • 4 bytes for an FP32 master copy of the weights.
  • 4 bytes for optimizer first-moment states (32-bit).
  • 4 bytes for optimizer second-moment states (32-bit).

Total: ~16 bytes per parameter, i.e. about 112 GB for a 7B model before activations.

Overall memory = Model parameters + Gradients + Optimizer states + Activations.
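The static part of this sum (everything except activations) can be computed directly. This sketch assumes the common mixed-precision AdamW layout with an FP32 master copy of the weights (2 + 2 + 4 + 4 + 4 = 16 bytes per parameter) and uses decimal gigabytes (1 GB = 10⁹ bytes).

```python
# Assumed mixed-precision AdamW layout; activations are deliberately excluded
BYTES_PER_PARAM = {
    "weights (BF16)": 2,
    "gradients (BF16)": 2,
    "master weights (FP32)": 4,
    "Adam first moment (FP32)": 4,
    "Adam second moment (FP32)": 4,
}


def static_training_memory_gb(n_params):
    """Weights + gradients + optimizer states in GB (1 GB = 1e9 bytes), activations excluded."""
    return n_params * sum(BYTES_PER_PARAM.values()) / 1e9


n = 7_000_000_000
print(n * 2 / 1e9)                   # 14.0 GB: BF16 weights alone, as for inference
print(static_training_memory_gb(n))  # 112.0 GB: before activations are counted
```

Activation memory comes on top and depends on batch size, sequence length and whether activation checkpointing is used.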

Tokenizer

Common tokenization methods include Byte Pair Encoding (BPE), Byte-level BPE (BBPE), Unigram Language Model (ULM), and WordPiece.
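BPE training repeatedly merges the most frequent adjacent symbol pair. The minimal sketch below starts from characters and uses a toy corpus; ties are broken by first occurrence, which is an arbitrary choice. Byte-level BPE runs the same loop over UTF-8 bytes instead of characters.

```python
from collections import Counter


def learn_bpe_merges(words, num_merges):
    """Learn BPE merge rules from a list of words, starting from single characters."""
    vocab = Counter(tuple(word) for word in words)  # symbol sequence -> frequency
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency
        pairs = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # ties go to the pair seen first
        merges.append(best)
        # Rewrite every word with the chosen pair fused into one symbol
        new_vocab = Counter()
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] += freq
        vocab = new_vocab
    return merges


corpus = ["low"] * 5 + ["lower"] * 2 + ["newest"] * 6 + ["widest"] * 3  # toy corpus
print(learn_bpe_merges(corpus, 3))  # [('e', 's'), ('es', 't'), ('l', 'o')]
```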

Positional Encoding

Since the attention mechanism has no inherent notion of token order, positional encodings are added to the token embeddings. The original Transformer uses fixed sinusoidal encodings: an [S, H] matrix, broadcast across the batch, whose entries are sines and cosines of position-dependent angles and therefore lie between -1 and 1.
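The sinusoidal table follows PE(pos, 2i) = sin(pos / 10000^(2i/H)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/H)). A plain-Python sketch for one sequence (no batch dimension, since the same table is added to every sequence):

```python
import math


def sinusoidal_positional_encoding(seq_len, d_model):
    """[seq_len][d_model] table: even dims use sin, odd dims use cos of the same angle."""
    pe = [[0.0] * d_model for _ in range(seq_len)]
    for pos in range(seq_len):
        for i in range(0, d_model, 2):  # i plays the role of 2i in the formula
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


pe = sinusoidal_positional_encoding(seq_len=8, d_model=4)
print(pe[0])                        # [0.0, 1.0, 0.0, 1.0]
print(min(min(row) for row in pe))  # negative: entries span [-1, 1], not [0, 1]
```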

Rotary Positional Encoding (RoPE)

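RoPE encodes position inside the attention score computation: query and key vectors are rotated, two dimensions at a time, by angles proportional to their token positions, so the query-key dot product depends only on the relative position of the two tokens.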
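The relative-position property can be checked numerically. This sketch pairs adjacent dimensions (2i, 2i+1) with frequency base^(-2i/d) and base 10000; that pairing is an assumption, and some implementations (for example the Hugging Face LLaMA code) pair dimension i with i + d/2 instead, which gives the same property.

```python
import math


def apply_rope(vec, pos, base=10000.0):
    """Rotate each adjacent pair (vec[2i], vec[2i+1]) by angle pos * base^(-2i/d)."""
    d = len(vec)
    assert d % 2 == 0, "RoPE needs an even dimension"
    out = list(vec)
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out[i] = vec[i] * c - vec[i + 1] * s
        out[i + 1] = vec[i] * s + vec[i + 1] * c
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.1, 0.2, 0.3, 0.4]
k = [0.5, -0.1, 0.2, 0.7]
# Same relative offset (query two positions after the key) at different absolute positions
score_a = dot(apply_rope(q, 5), apply_rope(k, 3))
score_b = dot(apply_rope(q, 12), apply_rope(k, 10))
print(math.isclose(score_a, score_b))  # True
```

Rotating q by mθ and k by nθ leaves qᵀR((n - m)θ)k, so only the offset n - m survives in the score.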

Attention Inference Optimization

KV Cache

During autoregressive generation, the keys and values of previous tokens are cached, so each new step only computes K and V for the newest token instead of recomputing them for the whole prefix. KV cache optimizations focus on sparsity, quantization, memory allocation, windowing, and sharing.

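# Example of using past_key_values for incremental generation (Hugging Face transformers).
# Assumes `model` is a causal LM and `output_step1` / `attention_mask_step1` come from a
# previous call made with use_cache=True.
import torch

next_token = torch.tensor([[3837]], dtype=torch.long)
output_step2 = model(
    input_ids=next_token,  # only the new token; earlier keys/values come from the cache
    attention_mask=torch.cat([attention_mask_step1, torch.ones_like(next_token)], dim=-1),  # past + new positions
    past_key_values=output_step1.past_key_values,
    use_cache=True,
)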
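Why the cache is safe to use can be shown with a toy single-head attention in plain Python: projecting only the newest token and appending its K/V gives the same outputs as re-projecting the whole prefix at every step. The projection matrices and embeddings below are hypothetical toy values.

```python
import math


def matvec(x, W):
    """x: [d_in], W: [d_in][d_out] -> [d_out]."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q, keys, values):
    """Single-head scaled dot-product attention of one query over the given keys/values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z for j in range(len(values[0]))]


# Hypothetical toy projections and token embeddings
Wq = [[0.2, -0.1], [0.4, 0.3]]
Wk = [[0.5, 0.1], [-0.2, 0.6]]
Wv = [[1.0, 0.0], [0.3, -0.7]]
tokens = [[1.0, 0.5], [0.2, -1.0], [0.7, 0.7], [-0.4, 0.9]]

# Without a cache: re-project K and V for the entire prefix at every step
no_cache = []
for t in range(len(tokens)):
    keys = [matvec(x, Wk) for x in tokens[: t + 1]]
    values = [matvec(x, Wv) for x in tokens[: t + 1]]
    no_cache.append(attend(matvec(tokens[t], Wq), keys, values))

# With a cache: project only the newest token and append its K/V
k_cache, v_cache, with_cache = [], [], []
for x in tokens:
    k_cache.append(matvec(x, Wk))
    v_cache.append(matvec(x, Wv))
    with_cache.append(attend(matvec(x, Wq), k_cache, v_cache))

print(with_cache == no_cache)  # True
```

The uncached loop performs O(S²) key/value projections over a sequence of length S, while the cached loop performs O(S), at the price of storing every past K and V.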

Multi-Query Attention (MQA) and Grouped Query Attention (GQA)

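Multi-Query Attention (MQA) keeps many query heads but shares a single key head and a single value head across all of them, shrinking the KV cache and memory traffic. Grouped Query Attention (GQA) splits the query heads into groups, each sharing one key/value head, offering a balance between MQA's efficiency and the quality of full multi-head attention.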

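import torch

def repeat_kv(tensor: torch.Tensor, repetitions: int) -> torch.Tensor:
    # [batch, kv_heads, seq_len, head_dim] -> [batch, kv_heads * repetitions, seq_len, head_dim]
    batch, kv_heads, seq_len, head_dim = tensor.shape
    if repetitions == 1:
        return tensor
    # Add a group axis and broadcast it with expand(), then fold it into the head axis
    expanded = tensor[:, :, None, :, :].expand(batch, kv_heads, repetitions, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * repetitions, seq_len, head_dim)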
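The reshape in `repeat_kv` places copy r of K/V head g at position g × repetitions + r, so query head j reads K/V head j // repetitions. A stdlib-only sketch of that mapping (the helper name is hypothetical) shows MHA, GQA and MQA as points on one scale:

```python
def kv_head_for(query_head, n_heads, n_kv_heads):
    """Index of the shared K/V head that a query head uses, matching repeat_kv's ordering."""
    if n_heads % n_kv_heads:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    group_size = n_heads // n_kv_heads  # = repetitions in repeat_kv
    return query_head // group_size


for n_kv in (8, 2, 1):  # MHA, GQA, MQA for 8 query heads
    print(n_kv, [kv_head_for(q, 8, n_kv) for q in range(8)])
# 8 [0, 1, 2, 3, 4, 5, 6, 7]
# 2 [0, 0, 0, 0, 1, 1, 1, 1]
# 1 [0, 0, 0, 0, 0, 0, 0, 0]
```

The KV cache shrinks by the factor n_heads / n_kv_heads, while the number of query heads, and hence the attention FLOPs, stays the same.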

Sliding Window Attention (SWA)

Limits each token to attending to a fixed window of nearby tokens (in causal decoders, the W most recent ones), reducing the attention-score computation from O(S²) to O(S × W), where W is the window size.
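A causal sliding-window mask can be written as a boolean matrix. In this sketch, which is one of several conventions, the window includes the token itself, so query i sees keys i - W + 1 through i.

```python
def sliding_window_causal_mask(seq_len, window):
    """mask[i][j] is True if query i may attend to key j: j <= i and i - j < window."""
    return [[j <= i and i - j < window for j in range(seq_len)] for i in range(seq_len)]


mask = sliding_window_causal_mask(seq_len=6, window=3)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
print(sum(map(sum, mask)))  # 15 attended pairs, at most seq_len * window = 18
```

Each row has at most W allowed entries, which is where the O(S × W) bound comes from.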

FlashAttention

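Computes exact attention while reducing memory traffic between GPU high-bandwidth memory and on-chip SRAM: Q, K, and V are processed in tiles with an online (running) softmax, so the full S × S score matrix is never materialized, and the backward pass recomputes attention instead of storing it.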
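The key trick that makes tiling possible is the online softmax: keys and values are visited one block at a time while a running maximum, running denominator and running weighted sum are rescaled as new blocks arrive. This scalar sketch shows only that idea for one query row; it is not FlashAttention itself (no GPU kernels, SRAM management or backward pass).

```python
import math


def softmax_attention_row(scores, values):
    """Reference: softmax over all scores at once, then a weighted sum of value rows."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z for j in range(len(values[0]))]


def online_softmax_attention_row(scores, values, block_size):
    """Same result, visiting keys/values one tile at a time with a running max and sum."""
    running_max, denom = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(scores), block_size):
        block_scores = scores[start:start + block_size]
        block_values = values[start:start + block_size]
        new_max = max(running_max, max(block_scores))
        # Rescale what was accumulated under the old maximum (factor is 0 on the first tile)
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(block_scores, block_values):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


scores = [0.5, 2.0, -1.0, 3.0, 0.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
print(softmax_attention_row(scores, values))
print(online_softmax_attention_row(scores, values, block_size=2))  # same values
```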

PagedAttention

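Stores the KV cache in fixed-size blocks (pages) that need not be contiguous in memory, with a per-sequence block table mapping token positions to physical blocks. This reduces fragmentation and over-reservation, improving memory utilization during long sequence generation.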
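The bookkeeping behind paging can be illustrated with a toy block table. The class below is hypothetical and tracks only where each token's K/V would live; it is not vLLM's implementation, which also handles sharing, copy-on-write and the actual GPU tensors.

```python
class ToyPagedKVCache:
    """Hypothetical block-table bookkeeping; stores slot locations, not actual K/V tensors."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.block_tables = {}                      # seq_id -> [physical block ids]
        self.lengths = {}                           # seq_id -> tokens stored

    def append_token(self, seq_id):
        """Reserve a slot for the next token; returns (physical_block, offset_in_block)."""
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:  # no block yet, or the last block is full
            if not self.free_blocks:
                raise MemoryError("out of KV-cache blocks")
            table.append(self.free_blocks.pop(0))
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def free(self, seq_id):
        """Return a finished sequence's blocks to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


cache = ToyPagedKVCache(num_blocks=4, block_size=2)
for seq in ["a", "a", "b", "a"]:
    cache.append_token(seq)
print(cache.block_tables)  # {'a': [0, 2], 'b': [1]}: sequence "a" is not contiguous
```

Memory is claimed one block at a time as a sequence grows, so at most one partially filled block per sequence is wasted.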

Quantization

Reduces model weight precision (e.g., from 16-bit to 8-bit or 4-bit) to decrease memory usage and accelerate inference.
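One of the simplest schemes is symmetric absmax quantization to 8 bits: pick one scale per tensor so the largest magnitude maps to 127, then round. This per-tensor sketch is illustrative; production methods typically use per-channel or per-group scales and more careful rounding.

```python
def quantize_int8_absmax(weights):
    """Symmetric per-tensor quantization: w ~= q * scale, q an integer in [-127, 127]."""
    absmax = max(abs(w) for w in weights)
    scale = absmax / 127 if absmax > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    return [x * scale for x in q]


weights = [0.42, -1.27, 0.0, 0.9, -0.05]  # toy values
q, scale = quantize_int8_absmax(weights)
restored = dequantize(q, scale)
print(q)  # small integers that fit in one signed byte each
print(max(abs(a - b) for a, b in zip(weights, restored)) <= scale / 2)  # True
```

Storing q takes 1 byte per weight instead of 2 for BF16, plus one float scale per tensor; rounding error is bounded by half a quantization step.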

Decoder-Only Inference

Techniques like continuous batching and speculative decoding improve throughput and reduce latency.
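Speculative decoding lets a cheap draft model propose several tokens that the large target model then verifies. The sketch below covers only the greedy variant, with hypothetical toy next-token functions standing in for models; the sampling variant uses a rejection-sampling acceptance rule to preserve the target distribution.

```python
def greedy_decode(next_token, prompt, max_new):
    """Plain greedy decoding: one target-model call per generated token."""
    out = list(prompt)
    for _ in range(max_new):
        out.append(next_token(out))
    return out


def speculative_greedy_decode(target_next, draft_next, prompt, k, max_new):
    """Greedy speculative decoding: draft proposes k tokens, target keeps the agreeing prefix."""
    out = list(prompt)
    while len(out) - len(prompt) < max_new:
        draft = []
        for _ in range(k):
            draft.append(draft_next(out + draft))
        # A real system scores all k draft positions in ONE batched target forward pass;
        # here the target is queried position by position for clarity.
        accepted = []
        for tok in draft:
            expected = target_next(out + accepted)
            accepted.append(expected)
            if tok != expected:
                break  # first disagreement: keep the target's token, discard the rest
        else:
            accepted.append(target_next(out + accepted))  # all k accepted: one bonus token
        out.extend(accepted)
    return out[: len(prompt) + max_new]


# Hypothetical toy "models": deterministic next-token functions over 11 token ids
def target_next(seq):
    return (sum(seq) * 7 + 3) % 11


def draft_next(seq):
    return target_next(seq) if len(seq) % 3 else 0  # agrees with the target most of the time


prompt = [1, 2]
fast = speculative_greedy_decode(target_next, draft_next, prompt, k=4, max_new=10)
print(fast == greedy_decode(target_next, prompt, 10))  # True
```

Every emitted token equals the target's greedy choice for its prefix, so the output is unchanged; the speedup comes from verifying several draft tokens per target pass.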

Pretraining Tasks

Masked Language Modeling (MLM)

Used in models like BERT: a random subset of input tokens (15% in BERT) is selected, mostly replaced with a [MASK] token, and the model predicts the original tokens at those positions.

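import torch

class MaskedLanguageModel(torch.nn.Module):
    """MLM head: maps each hidden state to log-probabilities over the vocabulary."""
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, vocab_size)
        self.log_softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # x: [batch, seq_len, hidden_size] -> [batch, seq_len, vocab_size] log-probabilities,
        # to be paired with torch.nn.NLLLoss evaluated on the masked positions
        return self.log_softmax(self.linear(x))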
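The inputs and labels for this head come from a masking step. The sketch follows the BERT recipe of 80% [MASK], 10% random token, 10% unchanged among the selected positions; the token ids are illustrative, and unlike real implementations it does not exclude special tokens. The label -100 marks positions to ignore, matching the default `ignore_index` of PyTorch's NLLLoss and CrossEntropyLoss.

```python
import random


def mask_tokens(token_ids, mask_id, vocab_size, mlm_prob=0.15, seed=0):
    """Select ~mlm_prob of positions; of those, 80% -> mask_id, 10% -> random id, 10% unchanged."""
    rng = random.Random(seed)
    inputs = list(token_ids)
    labels = [-100] * len(token_ids)  # -100: position ignored by the loss
    for i, tok in enumerate(token_ids):
        if rng.random() >= mlm_prob:
            continue
        labels[i] = tok  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)
        # else: leave the original token in place
    return inputs, labels


inputs, labels = mask_tokens([101, 2023, 2003, 1037, 7099, 102], mask_id=103, vocab_size=30522, mlm_prob=0.5)
print(inputs, labels)
```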

Causal Language Modeling

Used in GPT-style models, predicting the next token in a sequence.
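The training objective is the average negative log-likelihood of each token given only the tokens to its left, which amounts to using targets shifted one position relative to the inputs. In this sketch `next_token_probs` is a hypothetical model interface returning a probability for each candidate next token.

```python
import math


def causal_lm_loss(token_ids, next_token_probs):
    """Average next-token negative log-likelihood; the model only sees the prefix to the left."""
    losses = []
    for t in range(1, len(token_ids)):
        prefix, target = token_ids[:t], token_ids[t]  # targets are the inputs shifted by one
        losses.append(-math.log(next_token_probs(prefix)[target]))
    return sum(losses) / len(losses)


def uniform_model(prefix, vocab=(0, 1, 2, 3)):
    """Hypothetical baseline that ignores the prefix and guesses uniformly."""
    return {tok: 1 / len(vocab) for tok in vocab}


print(causal_lm_loss([2, 0, 3, 1], uniform_model))  # log(4), about 1.386
```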

Downstream Tasks

Text Classification / Sentence Embedding

Uses the final hidden states ([batch, sequence_length, hidden_size]). Pooling, e.g. taking the [CLS] token's hidden state or averaging over non-padding tokens, reduces them to a fixed-size [batch, hidden_size] representation.
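Both pooling strategies are short to write for a single sequence (the batch dimension is dropped for clarity). Mean pooling must use the attention mask so that padding positions do not dilute the average.

```python
def cls_pool(hidden):
    """hidden: [seq_len][hidden_size] for one sequence; returns the first ([CLS]) vector."""
    return hidden[0]


def mean_pool(hidden, attention_mask):
    """Average the hidden vectors of real tokens (mask == 1), ignoring padding."""
    kept = [vec for vec, m in zip(hidden, attention_mask) if m]
    return [sum(col) / len(kept) for col in zip(*kept)]


hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last position is padding
print(cls_pool(hidden))              # [1.0, 2.0]
print(mean_pool(hidden, [1, 1, 0]))  # [2.0, 3.0]
```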

Named Entity Recognition (NER)

Token-level classification: the tokenizer's character offsets are used to map character-level entity spans onto per-token labels.

# Example model head for token classification
class TokenClassificationModel(BertPreTrainedModel):
    ...
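The offset-based alignment can be sketched independently of the model head. This version assumes end-exclusive character spans, marks special tokens by a (0, 0) offset (the convention Hugging Face fast tokenizers typically use in their offset mappings), and emits BIO tags; a token only partly inside an entity is labeled "O" here, which is a simplification.

```python
def token_labels_from_offsets(offsets, entities):
    """Map (start, end, label) character spans to per-token BIO tags; None for special tokens."""
    labels = []
    for start, end in offsets:
        if start == end:
            labels.append(None)  # special token: excluded from the loss
            continue
        tag = "O"
        for ent_start, ent_end, ent_label in entities:
            if ent_start <= start and end <= ent_end:
                tag = ("B-" if start == ent_start else "I-") + ent_label
                break
        labels.append(tag)
    return labels


# "Ada Lovelace wrote code", with "Lovelace" split into two sub-word tokens
offsets = [(0, 0), (0, 3), (4, 8), (8, 12), (13, 18), (19, 23), (0, 0)]
print(token_labels_from_offsets(offsets, [(0, 12, "PER")]))
# [None, 'B-PER', 'I-PER', 'I-PER', 'O', 'O', None]
```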

Question Answering

Predicts start and end positions of answer spans within a context.

# Example model head for QA
class QAModel(BertPreTrainedModel):
    ...

Tags: Transformer Attention kv-cache RoPE FlashAttention

Posted on Wed, 24 Jun 2026 17:35:17 +0000 by Sul