Model Structure
The basic unit consists of token embedding with positional encoding, encoder, and decoder.
- Encoder: Self-attention layer with skip connections and layer normalization, followed by a feed-forward network (FFN) with skip connections and layer normalization.
- Decoder: Masked (causal) self-attention layer with skip connections and layer normalization, cross-attention layer over the encoder output with skip connections and layer normalization, and a feed-forward network with skip connections and layer normalization.
Computational Complexity and Parameters
Before training a large model, scaling laws are used to estimate data requirements, computational power (FLOPs), and training time.
In deep learning, each forward pass is dominated by matrix multiplications. For each token, every parameter takes part in one multiplication and one addition, counted as two floating-point operations (FLOPs).
The following conventions are adopted:
- L: Number of Transformer layers.
- H: Model hidden dimension (d_model).
- h: Number of attention heads.
- B: Batch size.
- S: Sequence length (e.g., 2K for GPT-3, 4K for LLaMA 2).
- V: Vocabulary size.
Inference Memory
Memory consumption comes primarily from model parameters, the KV cache, and intermediate activations. Intermediate activations typically contribute a smaller portion, on the order of batch size × sequence length × hidden size (B × S × H) values.
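As a rough sketch of this estimate, the snippet below adds up weight memory and KV-cache memory. The configuration (7B parameters, 32 layers, hidden size 4096) and the function names are illustrative assumptions, and the KV-cache formula assumes standard multi-head attention, where every layer stores one key and one value vector of size H per token.

```python
def weight_bytes(num_params: int, bytes_per_param: int = 2) -> int:
    """Static weight memory, assuming 16-bit parameters by default."""
    return num_params * bytes_per_param


def kv_cache_bytes(L: int, H: int, B: int, S: int, bytes_per_value: int = 2) -> int:
    """KV cache size: 2 tensors (K and V) per layer, each of shape [B, S, H]."""
    return 2 * L * B * S * H * bytes_per_value


# Hypothetical 7B-class configuration (assumed values, not taken from a specific model card).
weights = weight_bytes(7_000_000_000)
cache = kv_cache_bytes(L=32, H=4096, B=1, S=4096)

print(f"weights:  {weights / 1e9:.1f} GB")
print(f"KV cache: {cache / 2**30:.1f} GiB")
```

The cache grows linearly in both B and S, which is why long contexts and large batches can make it rival the weights.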
Attention Mechanism
Given input embeddings of shape [B, S, H]:
- Project Q, K, V matrices via dense layers: [B, S, H] × [H, H] = [B, S, H], once for each of Q, K, and V. Computational cost: 3 × B × S × H² multiply-adds.
- Compute attention scores: softmax(Q × Kᵀ / √H'), where H' = H / h is the per-head dimension. Shape: [B, h, S, H'] × [B, h, H', S] = [B, h, S, S]. Computational cost: B × h × S² × H' = B × S² × H multiply-adds.
- Multiply with V: [B, h, S, S] × [B, h, S, H'] = [B, h, S, H']. Computational cost: B × h × S² × H' = B × S² × H multiply-adds.
- Concatenate the heads and project back via the output dense layer: [B, S, H] × [H, H] = [B, S, H]. Computational cost: B × S × H² multiply-adds.
Total FLOPs for Attention: 2 × (4 × B × S × H² + 2 × B × S² × H) = 8 × B × S × H² + 4 × B × H × S². The factor of 2 accounts for the multiplication and the addition in each multiply-add.
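The total can be checked by tallying the four matrix products one at a time. This is a minimal sketch: the function name and the sample sizes are assumptions, softmax and scaling costs are ignored, and each multiply-add is counted as 2 FLOPs with per-head dimension H' = H / h.

```python
def attention_flops(B: int, S: int, H: int, h: int) -> int:
    """Forward FLOPs of one attention block, counting each multiply-add as 2 FLOPs."""
    head_dim = H // h                        # H' in the text
    qkv_macs = 3 * B * S * H * H             # Q, K, V projections
    score_macs = B * h * S * S * head_dim    # Q x K^T
    value_macs = B * h * S * S * head_dim    # scores x V
    out_macs = B * S * H * H                 # output projection
    return 2 * (qkv_macs + score_macs + value_macs + out_macs)


B, S, H, h = 2, 128, 512, 8                  # illustrative sizes only
total = attention_flops(B, S, H, h)
closed_form = 8 * B * S * H**2 + 4 * B * H * S**2
print(total == closed_form)
```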
Feed-Forward Network (FFN)
Given input of shape [B, S, H]:
- First linear layer: [B, S, H] × [H, 4H] = [B, S, 4H]. Computational cost: 4 × B × S × H² multiply-adds.
- Second linear layer: [B, S, 4H] × [4H, H] = [B, S, H]. Computational cost: 4 × B × S × H² multiply-adds.
Total FLOPs for FFN: 2 × 8 × B × S × H² = 16 × B × S × H², counting each multiply-add as 2 FLOPs.
Total Computational Cost
Forward Pass per Layer: 24 × B × S × H² + 4 × B × H × S² FLOPs.
L Layers: L × (24 × B × S × H² + 4 × B × H × S²).
Generation Step (output projection to vocabulary logits, [B, S, H] × [H, V] = [B, S, V]): 2 × B × S × H × V FLOPs.
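Backpropagation computes gradients with respect to both activations and weights, so it costs roughly twice the forward pass. Forward plus backward is therefore about three times the forward cost: with 2 FLOPs per parameter per token in the forward pass, training requires approximately 6 FLOPs per parameter per token.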
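The link between the per-layer formula and the per-parameter rule can be sketched numerically. The snippet assumes a layer holds 12 × H² weights (4 × H² for the attention projections and 8 × H² for the FFN, ignoring biases and normalization parameters); the sizes are illustrative.

```python
def layer_params(H: int) -> int:
    """Weights per layer: 4*H^2 (Q, K, V, output) + 8*H^2 (two FFN matrices)."""
    return 4 * H * H + 8 * H * H


def layer_forward_flops(B: int, S: int, H: int) -> int:
    """Per-layer forward cost from the text."""
    return 24 * B * S * H**2 + 4 * B * H * S**2


B, S, H = 1, 2048, 8192                      # illustrative sizes only
tokens = B * S

dense_flops = 2 * layer_params(H) * tokens   # 2 FLOPs per parameter per token
attn_extra = layer_forward_flops(B, S, H) - dense_flops

# The S^2 term is a small correction when S is much smaller than H: it equals S / (6 * H).
ratio = attn_extra / dense_flops

# Forward (1x) plus backward (about 2x) gives about 6 FLOPs per parameter per token.
train_flops_per_param_token = 3 * dense_flops / (layer_params(H) * tokens)
print(ratio, train_flops_per_param_token)
```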
Example: LLaMA 65B Model
Model Parameters: 65 × 10⁹. Training Tokens: 1.4 × 10¹².
Required FLOPs ≈ 6 × (Parameters × Tokens).
Practical FLOPs per second = Number of GPUs × Single GPU peak FLOPS × GPU Utilization. For instance: 2048 GPUs × 312 × 10¹² FLOPS (A100, 312 TFLOPS) × 0.45 utilization.
Training Time = Required FLOPs / Practical FLOPs.
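Plugging the numbers from this example into the formulas gives a quick estimate. The variable names are illustrative, and 0.45 utilization is the assumed figure from the text, not a measured value.

```python
params = 65e9                    # model parameters
tokens = 1.4e12                  # training tokens

required_flops = 6 * params * tokens          # about 5.46e23 FLOPs

num_gpus = 2048
peak_flops_per_gpu = 312e12                   # A100 peak, 312 TFLOPS
utilization = 0.45                            # assumed utilization from the text
practical_flops_per_sec = num_gpus * peak_flops_per_gpu * utilization

seconds = required_flops / practical_flops_per_sec
days = seconds / 86_400
print(f"{days:.1f} days")                     # roughly 22 days
```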
Memory Footprint
Models are typically loaded in 16-bit precision (e.g., BF16), where each parameter occupies 2 bytes.
For a 7B parameter model, static weight memory is approximately 14 GB.
During training with AdamW and mixed precision, memory per parameter includes:
- 2 bytes for model weights (16-bit).
- 2 bytes for weight gradients (16-bit).
- 4 bytes for the master copy of the weights (32-bit).
- 4 bytes for optimizer first-moment states (32-bit).
- 4 bytes for optimizer second-moment states (32-bit).
Total: ~16 bytes per parameter.
Overall memory = Model parameters + Gradients + Optimizer states + Activations.
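A small sketch of the per-parameter accounting, excluding activations. It assumes that a 32-bit master copy of the weights is kept alongside the 16-bit weights, as is common in mixed-precision training; the dictionary keys are illustrative names.

```python
# Bytes per parameter for mixed-precision AdamW training (assumed breakdown).
BYTES_PER_PARAM = {
    "weights_16bit": 2,
    "gradients_16bit": 2,
    "master_weights_32bit": 4,   # assumption: FP32 master copy is kept
    "adam_first_moment_32bit": 4,
    "adam_second_moment_32bit": 4,
}


def training_state_bytes(num_params: int) -> int:
    """Weights + gradients + optimizer states, without activations."""
    return num_params * sum(BYTES_PER_PARAM.values())


total = training_state_bytes(7_000_000_000)
print(f"{total / 1e9:.0f} GB before activations")
```

For a 7B model this is about eight times the 14 GB needed just to load the weights.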
Tokenizer
Common tokenization methods include Byte Pair Encoding (BPE), Byte-level BPE (BBPE), Unigram Language Model (ULM), and WordPiece.
Positional Encoding
Since the attention mechanism lacks inherent positional awareness, positional encodings are added to token embeddings. In the original Transformer, sinusoidal positional encodings are used, forming a matrix of shape [S, H] that is broadcast across the batch to [B, S, H], with values between -1 and 1.
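A minimal pure-Python sketch of the sinusoidal table follows, using the formulation from the original Transformer (sine on even dimensions, cosine on odd dimensions, base 10000). The function name is an assumption, and d_model is assumed to be even.

```python
import math


def sinusoidal_positions(seq_len: int, d_model: int) -> list:
    """Return a [seq_len][d_model] table of sinusoidal positional encodings."""
    table = []
    for pos in range(seq_len):
        row = []
        for i in range(0, d_model, 2):
            # Each dimension pair (i, i + 1) shares one frequency.
            angle = pos / (10000 ** (i / d_model))
            row.append(math.sin(angle))   # even dimension
            row.append(math.cos(angle))   # odd dimension
        table.append(row)
    return table


pe = sinusoidal_positions(seq_len=4, d_model=8)
print(pe[0])   # position 0: sin(0) = 0 and cos(0) = 1 in every pair
```

The same table is added to every sequence in the batch, so it only needs to be computed once.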
Rotary Positional Encoding (RoPE)
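RoPE encodes relative positional information directly in the attention score computation: each query and key vector is rotated by an angle proportional to its token position, so their dot product depends only on the offset between the two positions.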
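The relative-position property can be checked with a small sketch. It assumes the interleaved-pair convention (dimensions 2i and 2i+1 are rotated together) and a base of 10000; the vectors and function names are made up for illustration.

```python
import math


def rotate_pairs(vec: list, position: int, base: float = 10000.0) -> list:
    """Rotate each consecutive pair of dimensions by position * theta_i."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = position * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out


def dot(a: list, b: list) -> float:
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Both pairs of positions are 3 apart, so the scores should match.
score_a = dot(rotate_pairs(q, 5), rotate_pairs(k, 2))
score_b = dot(rotate_pairs(q, 13), rotate_pairs(k, 10))
print(abs(score_a - score_b) < 1e-9)
```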
Attention Inference Optimization
KV Cache
During autoregressive generation, past key and value states are cached to avoid recomputing them at every step. KV cache optimization focuses on sparsity, quantization, memory allocation, windowing, and sharing.
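# Example of using past_key_values for incremental generation.
# Assumes `model` and `output_step1` (the previous step's output) already exist.
# Note: some implementations expect attention_mask to cover the cached tokens as well.
import torch

output_step2 = model(
    input_ids=torch.tensor([[3837]], dtype=torch.long),
    attention_mask=torch.tensor([[1]], dtype=torch.long),
    past_key_values=output_step1.past_key_values,
)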
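To make the saving concrete, here is a toy single-head sketch in pure Python. Everything in it is hypothetical: the fixed "projection" in `project_kv`, the use of the raw token vector as the query, and the tiny inputs. It only shows that caching keys and values yields the same outputs with fewer projection calls.

```python
import math


def project_kv(x, counter):
    """Toy stand-in for the K and V projections; counts how often it runs."""
    counter[0] += 1
    key = [x[0] + x[1], x[0] - x[1]]
    value = [2.0 * x[0], 0.5 * x[1]]
    return key, value


def attend(query, keys, values):
    """Scaled dot-product attention for one query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def generate_without_cache(tokens):
    counter, outputs = [0], []
    for t in range(len(tokens)):
        # Re-project the whole prefix at every step.
        kv = [project_kv(x, counter) for x in tokens[: t + 1]]
        outputs.append(attend(tokens[t], [k for k, _ in kv], [v for _, v in kv]))
    return outputs, counter[0]


def generate_with_cache(tokens):
    counter, outputs, keys, values = [0], [], [], []
    for x in tokens:
        k, v = project_kv(x, counter)   # only the new token is projected
        keys.append(k)
        values.append(v)
        outputs.append(attend(x, keys, values))
    return outputs, counter[0]


tokens = [[0.1, 0.2], [0.4, -0.3], [-0.5, 0.6]]
outputs_no_cache, calls_no_cache = generate_without_cache(tokens)
outputs_cache, calls_cache = generate_with_cache(tokens)
print(calls_no_cache, calls_cache)
```

Without the cache the number of projections grows quadratically with the number of generated tokens; with it, linearly.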
Multi-Query Attention (MQA) and Grouped Query Attention (GQA)
MQA uses a single key head and a single value head shared by all query heads, reducing KV cache memory and memory bandwidth. GQA splits the query heads into groups, each sharing one key and value head, offering a balance between efficiency and quality.
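import torch

def repeat_kv(tensor: torch.Tensor, repetitions: int) -> torch.Tensor:
    """Repeat each KV head `repetitions` times so K/V match the number of query heads."""
    batch, kv_heads, seq_len, head_dim = tensor.shape
    if repetitions == 1:
        return tensor
    # Insert a repeat axis after the KV-head axis, then fold it into the head axis.
    expanded = tensor[:, :, None, :, :].expand(batch, kv_heads, repetitions, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * repetitions, seq_len, head_dim)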
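The head layout produced by `repeat_kv` implies a simple mapping: query head j reads from KV head j // repetitions. The sketch below shows that mapping for multi-head attention (MHA), GQA, and MQA with 8 query heads. The helper name and the head counts are illustrative assumptions.

```python
def kv_head_for_query_head(query_head: int, num_query_heads: int, num_kv_heads: int) -> int:
    """Index of the KV head that a given query head attends with."""
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be a multiple of num_kv_heads")
    group_size = num_query_heads // num_kv_heads
    return query_head // group_size


mha = [kv_head_for_query_head(j, 8, 8) for j in range(8)]   # one KV head per query head
gqa = [kv_head_for_query_head(j, 8, 2) for j in range(8)]   # 2 groups of 4 query heads
mqa = [kv_head_for_query_head(j, 8, 1) for j in range(8)]   # a single shared KV head

print(mha)
print(gqa)
print(mqa)
```

Since only the KV heads are cached, the cache shrinks by the factor num_kv_heads / num_query_heads relative to MHA.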
Sliding Window Attention (SWA)
Limits attention to a fixed window around each token, reducing computation from O(S²) to O(S × W) where W is the window size.
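A minimal sketch of the mask, assuming the causal variant used in decoder-only models, where each token attends to itself and the W - 1 tokens before it. The function name and sizes are illustrative.

```python
def sliding_window_mask(seq_len: int, window: int) -> list:
    """mask[i][j] is True when query position i may attend to key position j."""
    return [[(i - window < j <= i) for j in range(seq_len)]
            for i in range(seq_len)]


mask = sliding_window_mask(seq_len=6, window=3)
attended_pairs = sum(sum(row) for row in mask)
full_causal_pairs = 6 * 7 // 2

for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
print(attended_pairs, full_causal_pairs)
```

The number of attended pairs is bounded by S × W, while full causal attention grows with S².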
FlashAttention
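Optimizes attention by reducing reads and writes to GPU high-bandwidth memory: attention is computed in tiles that fit in on-chip SRAM, and intermediate results are recomputed in the backward pass instead of storing the full S × S attention matrix.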
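The arithmetic that makes tiling possible is an online softmax: keys and values are processed block by block while a running maximum and running sum are kept, so the full row of scores never has to be held at once. The sketch below illustrates only that arithmetic for a single query in pure Python, not the GPU memory behavior; the names and inputs are assumptions.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_full(query, keys, values):
    """Reference: softmax over all scores at once."""
    scale = math.sqrt(len(query))
    scores = [dot(query, k) / scale for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def attention_tiled(query, keys, values, block_size=2):
    """Same result, visiting keys/values one block at a time."""
    scale = math.sqrt(len(query))
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        block_keys = keys[start:start + block_size]
        block_values = values[start:start + block_size]
        scores = [dot(query, k) / scale for k in block_keys]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, block_values):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


query = [0.2, -0.4]
keys = [[1.0, 0.0], [0.5, 0.5], [-1.0, 2.0], [0.3, -0.7], [2.0, 1.0]]
values = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [0.5, 0.5], [-2.0, 4.0]]
full = attention_full(query, keys, values)
tiled = attention_tiled(query, keys, values, block_size=2)
print(full, tiled)
```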
PagedAttention
Manages KV cache memory in non-contiguous pages, improving memory utilization during long sequence generation.
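The bookkeeping behind paging can be sketched with a toy block table that maps each sequence's logical token positions to physical blocks. The class, its method names, and the block sizes are hypothetical and do not mirror any particular library's API.

```python
class PagedKVCache:
    """Toy allocator: each sequence owns a list of fixed-size physical blocks."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}   # seq_id -> list of physical block ids
        self.lengths = {}        # seq_id -> number of cached tokens

    def append_token(self, seq_id):
        """Reserve a slot for one new token; returns (physical_block, offset)."""
        length = self.lengths.get(seq_id, 0)
        table = self.block_tables.setdefault(seq_id, [])
        if length % self.block_size == 0:        # current block is full (or none yet)
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = length + 1
        return table[length // self.block_size], length % self.block_size

    def free(self, seq_id):
        """Return all blocks of a finished sequence to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


cache = PagedKVCache(num_blocks=4, block_size=2)
slots_a = [cache.append_token("a") for _ in range(3)]
slot_b = cache.append_token("b")
print(slots_a, slot_b)
```

Because blocks are allocated on demand, a sequence wastes at most one partially filled block instead of a buffer reserved for the maximum length.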
Quantization
Reduces model weight precision (e.g., from 16-bit to 8-bit or 4-bit) to decrease memory usage and accelerate inference.
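One simple scheme is symmetric absmax quantization to 8-bit integers, sketched below in pure Python. It is one of several possible schemes, chosen here only for illustration; the function names and sample weights are assumptions.

```python
def quantize_int8(weights):
    """Map floats to integers in [-127, 127] using a single absmax scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate float values."""
    return [q * scale for q in quantized]


weights = [0.5, -1.27, 0.01, 1.0]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q, scale, max_error)
```

Each weight now needs 1 byte instead of 2, at the cost of a rounding error of at most half the scale.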
Decoder-Only Inference
Techniques like continuous batching and speculative decoding improve throughput and reduce latency.
Pretraining Tasks
Masked Language Modeling (MLM)
Used in models like BERT, where random tokens are masked and the model predicts them.
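import torch

class MaskedLanguageModel(torch.nn.Module):
    """Prediction head: maps hidden states to log-probabilities over the vocabulary."""

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, vocab_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # x: [batch, seq_len, hidden_size] -> [batch, seq_len, vocab_size]
        return self.softmax(self.linear(x))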
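The input side of MLM can be sketched as follows. This is a simplified version under stated assumptions: the 15% masking rate follows BERT, the mask id 103 and the ignore index -100 are illustrative values, and BERT's additional step of replacing some selected tokens with random or unchanged tokens is omitted.

```python
import random


def mask_tokens(token_ids, mask_id, mask_prob=0.15, ignore_index=-100, seed=0):
    """Return (inputs, labels): masked positions keep their original id as label."""
    rng = random.Random(seed)
    inputs, labels = [], []
    for tok in token_ids:
        if rng.random() < mask_prob:
            inputs.append(mask_id)        # model sees the mask token
            labels.append(tok)            # and must predict the original
        else:
            inputs.append(tok)
            labels.append(ignore_index)   # position is excluded from the loss
    return inputs, labels


ids = [2023, 2003, 1037, 7099, 6251, 2005, 1996, 2944, 2000, 3191]   # made-up ids
inputs, labels = mask_tokens(ids, mask_id=103)
print(inputs)
print(labels)
```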
Causal Language Modeling
Used in GPT-style models, predicting the next token in a sequence.
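In practice this means the targets are the inputs shifted by one position, and a causal mask stops each position from seeing later tokens. A minimal sketch with made-up token ids and hypothetical helper names:

```python
def next_token_pairs(token_ids):
    """Inputs are all tokens but the last; targets are the same sequence shifted left."""
    return token_ids[:-1], token_ids[1:]


def causal_mask(n):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]


inputs, targets = next_token_pairs([10, 11, 12, 13])
mask = causal_mask(len(inputs))
print(inputs, targets)
print(mask)
```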
Downstream Tasks
Text Classification / Sentence Embedding
Utilizes the final hidden states ([batch, sequence_length, hidden_size]). Pooling (e.g., CLS token pooling) produces fixed-size representations.
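Two common pooling choices can be sketched in pure Python on nested lists shaped [batch][sequence_length][hidden_size]. Mean pooling with an attention mask is shown as an alternative to CLS pooling; the helper names and the tiny inputs are illustrative assumptions.

```python
def cls_pool(hidden_states):
    """Take the hidden state of the first token of every sequence."""
    return [seq[0] for seq in hidden_states]


def mean_pool(hidden_states, attention_mask):
    """Average hidden states over non-padding tokens only."""
    pooled = []
    for seq, mask in zip(hidden_states, attention_mask):
        kept = [vec for vec, m in zip(seq, mask) if m]
        dim = len(seq[0])
        pooled.append([sum(v[d] for v in kept) / len(kept) for d in range(dim)])
    return pooled


hidden = [[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]]   # last token is padding
mask = [[1, 1, 0]]
print(cls_pool(hidden))
print(mean_pool(hidden, mask))
```

Either way, the output has shape [batch, hidden_size] regardless of the sequence length.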
Named Entity Recognition (NER)
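Token-level classification in which the tokenizer's offset mappings (the character start and end of each token) are used to align character-level entity annotations with token labels.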
# Example model head for token classification
class TokenClassificationModel(BertPreTrainedModel):
    ...
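The offset-to-label alignment can be sketched independently of any model. The example assumes a BIO tagging scheme and hand-written offsets for a made-up sentence; a real tokenizer would supply the offsets.

```python
def label_tokens(offsets, entities, outside="O"):
    """Assign a BIO label to each token from character-level entity spans.

    offsets:  list of (char_start, char_end) per token
    entities: list of (char_start, char_end, entity_type)
    """
    labels = []
    for start, end in offsets:
        label = outside
        for ent_start, ent_end, ent_type in entities:
            if start < end and start >= ent_start and end <= ent_end:
                prefix = "B-" if start == ent_start else "I-"
                label = prefix + ent_type
                break
        labels.append(label)
    return labels


text = "Ada Lovelace wrote notes"
# Hypothetical subword split: "Ada", "Love", "lace", "wrote", "notes"
offsets = [(0, 3), (4, 8), (8, 12), (13, 18), (19, 24)]
entities = [(0, 12, "PER")]

labels = label_tokens(offsets, entities)
print(labels)
```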
Question Answering
Predicts start and end positions of answer spans within a context.
# Example model head for QA
class QAModel(BertPreTrainedModel):
    ...
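Once a model has produced start and end logits, the answer span is typically chosen by maximizing their sum subject to start ≤ end. The sketch below shows one straightforward way to do that; the function name, the length limit, and the logits are illustrative assumptions.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end) maximizing start_logit + end_logit with start <= end."""
    best_score = float("-inf")
    best = (0, 0)
    for s, s_score in enumerate(start_logits):
        last = min(len(end_logits), s + max_answer_len)
        for e in range(s, last):
            score = s_score + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best


span = best_span([0.1, 2.0, 0.3, 0.2], [0.0, 0.5, 3.0, 0.1])
# Taking each argmax independently would give end < start here; the constraint avoids that.
constrained = best_span([0.1, 0.2, 5.0], [4.0, 0.1, 0.3])
print(span, constrained)
```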