Lookup Table Mechanics
In neural networks for sequence processing, the nn.Embedding module functions as a trainable lookup table. It maps discrete integer identifiers to continuous dense vectors. Rather than requiring sparse one-hot representations as inputs, the layer directly accepts integer indices and returns the corresponding rows of its weight matrix. Those rows are ordinary parameters: during backpropagation, only the rows that were looked up receive nonzero gradients, so training refines the representation of each index that appears in the data.
import torch
import torch.nn as nn
num_tokens = 5
feature_size = 4
lookup = nn.Embedding(num_tokens, feature_size)
print(lookup.weight)
# Output reveals a (5, 4) parameter matrix
target_idx = torch.tensor([3])
retrieved_vec = lookup(target_idx)
print(retrieved_vec)
# Retrieves the vector at row 3 of the weight matrix
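To make the backpropagation claim concrete, the following minimal sketch checks which rows of the weight matrix receive gradient after a lookup. The indices and the summed loss are illustrative assumptions chosen only to produce a scalar to differentiate; they are not part of the original example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lookup = nn.Embedding(5, 4)

# Look up rows 1 and 3, then reduce to a scalar so backward() can run.
# The sum is a stand-in loss used only for illustration.
indices = torch.tensor([1, 3])
loss = lookup(indices).sum()
loss.backward()

# The gradient has the same (5, 4) shape as the weight matrix.
print(lookup.weight.grad)

# Collect the row numbers whose gradient is not all zeros.
touched_rows = lookup.weight.grad.abs().sum(dim=1).nonzero().flatten().tolist()
print(touched_rows)  # [1, 3]
```

Rows 0, 2, and 4 keep a zero gradient, so a plain gradient step leaves them unchanged until their indices appear in a batch.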
Distinction from Linear Layers
While both modules map their inputs to vectors of a chosen dimensionality, they differ in how they store their weights and in how they compute their output.
Weight Orientation
An nn.Linear(in_dim, out_dim) layer stores its weight with shape (out_dim, in_dim), the reverse of the constructor's argument order. An nn.Embedding(num_items, dim) stores its weight with shape (num_items, dim), matching the argument order, so each row holds the vector for one item.
Computational Approach
The linear layer computes its output by matrix multiplication, input @ weight.T plus an optional bias, so every weight takes part in each call. The embedding module performs no multiplication in its forward pass; it indexes into its parameter tensor and returns the selected rows. The result equals what multiplying a one-hot vector by the weight matrix would give, but without building the one-hot vector or carrying out the multiplication.
import torch
import torch.nn as nn
linear_proj = nn.Linear(5, 4, bias=False)
print(linear_proj.weight.shape)
# Shape: (4, 5)
dense_input = torch.ones(5)
linear_out = linear_proj(dense_input)
# Computes: dense_input @ linear_proj.weight.T
embed_proj = nn.Embedding(5, 4)
print(embed_proj.weight.shape)
# Shape: (5, 4)
sparse_input = torch.tensor([2])
embed_out = embed_proj(sparse_input)
# Retrieves: embed_proj.weight[2], returned with shape (1, 4) because one index was passed
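The relationship between the two layers can be checked directly. The sketch below copies the transposed embedding table into a bias-free linear layer and compares an index lookup against a one-hot matrix multiplication. The variable names and the chosen indices are illustrative assumptions, and the weight sharing is done by hand purely for the comparison.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embed = nn.Embedding(5, 4)
linear = nn.Linear(5, 4, bias=False)

# Make the linear weight (4, 5) the transpose of the embedding table (5, 4).
with torch.no_grad():
    linear.weight.copy_(embed.weight.T)

idx = torch.tensor([2, 0])
# One-hot encode the indices into a (2, 5) float tensor.
one_hot = F.one_hot(idx, num_classes=5).float()

via_lookup = embed(idx)        # row selection, shape (2, 4)
via_matmul = linear(one_hot)   # one_hot @ linear.weight.T, shape (2, 4)

matches = torch.allclose(via_lookup, via_matmul)
print(matches)  # True
```

Both paths return the same rows; the lookup simply avoids allocating the one-hot tensor and multiplying by its zeros, which matters as the number of items grows.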