The Fundamental Concept of Generalization
Generalization is one of the most important concepts in machine learning. It describes how well a model trained on a specific dataset performs on data it has never seen.
In a machine learning workflow, we use a training dataset to build a model that maps inputs to outputs. However, the goal is not just good predictions on the training data: we want a model that also performs well on new data it was never trained on.
Generalization ability is therefore a central criterion for evaluating a model, and it is measured on held-out data rather than on the training set. A model that generalizes well makes accurate predictions on unseen data instead of merely excelling on its training samples.
When a model performs very well on its training data but poorly on new data, it is said to overfit. Overfitting happens when a model learns patterns specific to the training data too closely (including noise and coincidences), so that what it has learned does not carry over to new data. Conversely, when a model performs poorly on both the training data and new data, it is said to underfit.
To reduce the risk of overfitting and underfitting, practitioners use techniques such as regularization and feature selection, together with evaluation protocols such as cross-validation that give a reliable estimate of performance on unseen data. All of these aim to improve how well the model generalizes, and therefore how useful its predictions are.
Core Challenge: Optimization Versus Generalization
In the earlier examples (sentiment analysis, news categorization, and price prediction), we split the data into training, validation, and test sets.
The reason not to evaluate a model on the same data it was trained on quickly became apparent: after just a few epochs, performance on never-before-seen data started to diverge from performance on the training data, which keeps improving as training goes on. The models started to overfit. Overfitting happens in every machine learning problem.
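The three-way split can be sketched in NumPy as follows. This is a minimal illustration, not the procedure used in the earlier examples; the helper name `split_dataset`, the 60/20/20 proportions, and the toy data are assumptions chosen for the sketch.

```python
import numpy as np

def split_dataset(x, y, val_frac=0.2, test_frac=0.2, seed=0):
    """Hypothetical helper: shuffle once, then carve out validation and test sets."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(x))
    n_test = int(len(x) * test_frac)
    n_val = int(len(x) * val_frac)
    test_idx = indices[:n_test]
    val_idx = indices[n_test:n_test + n_val]
    train_idx = indices[n_test + n_val:]
    return ((x[train_idx], y[train_idx]),
            (x[val_idx], y[val_idx]),
            (x[test_idx], y[test_idx]))

# Usage with toy data: 1,000 samples with 5 features each
x = np.arange(5000, dtype="float32").reshape(1000, 5)
y = np.arange(1000) % 2
(train_x, train_y), (val_x, val_y), (test_x, test_y) = split_dataset(x, y)
print(len(train_x), len(val_x), len(test_x))  # 600 200 200
```

The model is fitted on the training set, hyperparameters are tuned against the validation set, and the test set is used only at the end, for a final estimate of generalization.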
The fundamental issue in machine learning is the tension between optimization and generalization.
Optimization involves adjusting models to achieve optimal performance on training data (corresponding to learning in machine learning), while generalization refers to model performance on previously unseen data.
The goal is good generalization, but you cannot control generalization directly: you can only fit the model to its training data. Fit it too well, and overfitting sets in and generalization suffers.
Underfitting Versus Overfitting Dynamics
As training progresses, performance on held-out validation data first improves, then inevitably peaks, and then starts to decline.
This pattern is extremely common; you will see it with virtually any kind of model and any dataset.
At the beginning of training, optimization and generalization are correlated: the lower the loss on training data, the lower the loss on test data. While this is happening, the model is underfitting: there is still progress to be made, because it hasn't yet modeled all the relevant patterns in the training data.
But after a certain number of iterations on the training data, generalization stops improving: validation metrics stall and then begin to degrade. The model is starting to overfit, learning patterns that are specific to the training data but misleading or irrelevant when it comes to new data.
Overfitting is particularly likely when the data is noisy, involves uncertainty, or contains rare features, as the following examples show.
Practical Examples of Data Challenges
Noisy Training Datasets
Real-world datasets frequently contain invalid inputs. For example, an MNIST digit might be an all-black image, or something so oddly shaped that it barely resembles a digit. Which digits are these? Often nobody can tell, yet they are part of the MNIST training set. Worse still are perfectly valid inputs that carry the wrong label.
If a model goes out of its way to incorporate such outliers, its generalization performance degrades. For example, if a handwritten 4 looks very similar to a 4 that was mislabeled as a 9 in the training set, the model may end up classifying it as a 9.
Ambiguous Features
Not all data noise comes from inaccuracies: even perfectly clean and neatly labeled data can be noisy when the problem involves uncertainty and ambiguity.
In classification tasks, some regions of the input feature space are often associated with multiple classes at the same time. Suppose you are building a model that takes an image of a banana as input and predicts whether the banana is unripe, ripe, or rotten. These categories have no objective boundaries, so the same picture might be labeled differently by different people.
Similarly, many problems involve randomness. You could use atmospheric pressure data to predict whether it will rain tomorrow, but the exact same measurements may be followed sometimes by rain and sometimes by a clear sky, with some probability.
A model can overfit such probabilistic data by being too confident about ambiguous regions of the feature space. A more robust fit would ignore individual data points and capture the overall trend.
Rare Features and Spurious Correlations
If you've only encountered two orange tabby cats in your life, both unfriendly, you might conclude that orange tabby cats are typically unfriendly—this exemplifies overfitting. Greater cat exposure, including more orange cats, would reveal that fur color bears little correlation with personality.
Similarly, machine learning models easily overfit when trained on datasets containing rare features. For sentiment classification tasks, if 'cherimoya' (a fruit native to the Andes) appears in only one training text with negative sentiment, an inadequately regularized model might assign high weight to this term, consistently classifying new texts mentioning cherimoya as negative. Objectively, the word 'cherimoya' contains no inherent negative sentiment.
Importantly, a feature value does not need to be rare to produce spurious correlations. Suppose a word appears in 100 samples of your training data, 54% of the time with positive sentiment and 46% of the time with negative sentiment. That difference may well be a complete statistical fluke, yet the model is likely to learn to use that feature for its classification task. This is one of the most common sources of overfitting.
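To see why a 54/46 split over 100 samples is weak evidence, you can compute how often a word with no real connection to sentiment would show a split at least that lopsided. The sketch below assumes, as a simplification, that each of the 100 occurrences is independently positive with probability 0.5; the helper name `binomial_tail` is introduced here for illustration and only the Python standard library is used.

```python
from math import comb

def binomial_tail(n, k, p=0.5):
    """P(X >= k) for X ~ Binomial(n, p)."""
    return sum(comb(n, i) * p**i * (1 - p)**(n - i) for i in range(k, n + 1))

n, positives = 100, 54
# One-sided: chance that a sentiment-neutral word shows 54 or more positives
p_one_sided = binomial_tail(n, positives)
# Two-sided: chance of a split at least this uneven in either direction
p_two_sided = 2 * p_one_sided
print(f"one-sided: {p_one_sided:.3f}, two-sided: {p_two_sided:.3f}")
```

Under these assumptions, a neutral word would produce a split at least this uneven in one direction or the other almost half the time, so a model that leans on it is fitting noise.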
Consider this striking example:
Take MNIST and create a new training set by concatenating 784 white-noise dimensions to the existing 784 pixel dimensions, so that half of the data is now noise. For comparison, also create an equivalent dataset by concatenating 784 all-zero dimensions:
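```python
import numpy as np
from tensorflow.keras.datasets import mnist

# Load MNIST and flatten each 28x28 image into a 784-dimensional vector in [0, 1]
(train_images, train_labels), _ = mnist.load_data()
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255

# Append 784 uniform-noise features: half of each input is now pure noise
train_images_with_noise_features = np.concatenate(
    [train_images, np.random.random((len(train_images), 784))], axis=1)

# Append 784 all-zero features: same input size, but the extra columns carry no information
train_images_with_zero_features = np.concatenate(
    [train_images, np.zeros((len(train_images), 784))], axis=1)
```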
Next, train the same model architecture on both datasets:
```python
from tensorflow import keras
from tensorflow.keras import layers

def build_network():
    # A small two-layer densely connected classifier
    network = keras.Sequential([
        layers.Dense(512, activation="relu"),
        layers.Dense(10, activation="softmax")
    ])
    network.compile(optimizer="rmsprop",
                    loss="sparse_categorical_crossentropy",
                    metrics=["accuracy"])
    return network

# Model trained on inputs padded with noise features
model = build_network()
history_noise = model.fit(
    train_images_with_noise_features, train_labels,
    epochs=10,
    batch_size=128,
    validation_split=0.2)

# Fresh model with the same architecture, trained on inputs padded with zeros
model = build_network()
history_zeros = model.fit(
    train_images_with_zero_features, train_labels,
    epochs=10,
    batch_size=128,
    validation_split=0.2)
```
Then compare how the validation accuracy of the two models evolves over time:
```python
import matplotlib.pyplot as plt

# Per-epoch validation accuracy recorded by model.fit()
val_acc_noise = history_noise.history["val_accuracy"]
val_acc_zeros = history_zeros.history["val_accuracy"]
epochs = range(1, 11)
plt.plot(epochs, val_acc_noise, "r-",
         label="Validation accuracy with noise features")
plt.plot(epochs, val_acc_zeros, "r--",
         label="Validation accuracy with zero features")
plt.title("Impact of noise features on validation accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
```
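Although both datasets hold the same useful information, the model trained with noise features typically ends up with lower validation accuracy, purely because of spurious correlations. Noisy features inevitably lead to overfitting.
Consequently, when you are unsure whether your features are informative or distracting, it is common to perform feature selection before training. For example, restricting the IMDB data to the 10,000 most frequent words was a crude form of feature selection. The typical approach is to compute a usefulness score for each feature, that is, a measure of how informative it is for the task (such as the mutual information between the feature and the labels), and to keep only the features whose score exceeds some threshold.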
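A usefulness score based on mutual information can be sketched with a simple plug-in estimator for discrete features. Everything here is an assumption made for illustration: the synthetic data (one binary feature that matches the label 90% of the time, plus three pure-noise binary features), the helper name `mutual_information`, and the 0.01-nat threshold. Libraries such as scikit-learn provide estimators that also handle continuous features.

```python
import numpy as np

def mutual_information(x, y):
    """Plug-in estimate of I(X; Y) in nats for two discrete 1-D arrays."""
    mi = 0.0
    for xv in np.unique(x):
        p_x = np.mean(x == xv)
        for yv in np.unique(y):
            p_y = np.mean(y == yv)
            p_xy = np.mean((x == xv) & (y == yv))
            if p_xy > 0:
                mi += p_xy * np.log(p_xy / (p_x * p_y))
    return mi

rng = np.random.default_rng(0)
n = 10_000
labels = rng.integers(0, 2, n)
# Informative feature: equals the label, except for 10% random flips
informative = np.where(rng.random(n) < 0.1, 1 - labels, labels)
# Three binary noise features with no relationship to the label
noise = rng.integers(0, 2, (n, 3))
features = np.column_stack([informative, noise])

scores = np.array([mutual_information(features[:, j], labels)
                   for j in range(features.shape[1])])
threshold = 0.01  # hypothetical cut-off, in nats
selected = np.flatnonzero(scores > threshold)
print(scores.round(4), selected)
```

Only the informative column clears the threshold; the noise columns score close to zero, so they would be dropped before training.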
The Nature of Deep Learning Generalization
A remarkable fact about deep learning models is that they can be trained to fit anything, as long as they have enough representational capacity.
To test this, shuffle the MNIST labels and train a model on the shuffled dataset:
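```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

(train_images, train_labels), _ = mnist.load_data()
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255

# Copy the labels before shuffling: train_labels[:] would be a NumPy view,
# so shuffling it in place would also shuffle train_labels
shuffled_train_labels = train_labels.copy()
np.random.shuffle(shuffled_train_labels)

model = keras.Sequential([
    layers.Dense(512, activation="relu"),
    layers.Dense(10, activation="softmax")
])
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# Inputs and labels are now unrelated; a long run lets the model memorize them
model.fit(train_images, shuffled_train_labels,
          epochs=100,
          batch_size=128,
          validation_split=0.2)
```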
Even though there is no relationship whatsoever between the inputs and the shuffled labels, the training loss goes down quite a bit, even with a relatively small model: the model is simply memorizing the training samples. The validation loss, however, does not improve at all over time, since there is no possibility of generalization in this setting.
The nature of generalization in deep learning has less to do with deep learning models themselves and more to do with the structure of information in the real world.
Manifold Hypothesis
The input to an MNIST classifier (before preprocessing) is a 28 × 28 array of integers between 0 and 255. The total number of possible input values is therefore 256^784, much greater than the number of atoms in the universe. However, very few of these inputs would look like valid MNIST samples: actual handwritten digits occupy only a tiny subspace of the parent space of all possible 28 × 28 uint8 arrays. What's more, this subspace is highly structured.
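The size of that input space is easy to check with Python's arbitrary-precision integers. The sketch compares it with the commonly quoted rough estimate of 10^80 atoms in the observable universe; that estimate is an outside assumption, not a figure from this text.

```python
import math

n_inputs = 256 ** (28 * 28)            # every possible 28x28 uint8 image
n_digits = len(str(n_inputs))          # number of decimal digits in that count
log10_inputs = 784 * math.log10(256)   # the same magnitude, via logarithms
atoms_in_universe = 10 ** 80           # rough, commonly cited estimate

print(n_digits)                        # 1889
print(round(log10_inputs, 2))          # 1888.06
print(n_inputs > atoms_in_universe)    # True
```

In other words, 256^784 is roughly 10^1888, a number whose exponent alone is more than twenty times that of the atom estimate.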
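The subspace of valid handwritten digits is continuous: if you take a sample and modify it a little, it will still be recognizable as the same handwritten digit. Furthermore, all samples in the valid subspace are connected by smooth paths that run through the subspace. If you take two random MNIST digits A and B, there exists a sequence of intermediate images that morphs A into B, such that any two consecutive images are very close to each other.
There may be a few ambiguous shapes close to the boundary between two classes, but even these shapes still look very much like digits.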
In technical terms, handwritten digits form a manifold within the space of possible 28 × 28 uint8 arrays. A manifold is a lower-dimensional subspace of some parent space that is locally similar to a linear (Euclidean) space. For example, a smooth curve in the plane is a one-dimensional manifold within a two-dimensional space, because at every point of the curve you can draw a tangent: near any point, the curve can be approximated by a line. A smooth surface in three-dimensional space is a two-dimensional manifold.
More generally, the manifold hypothesis posits that all natural data lies on a low-dimensional manifold within the high-dimensional space where it is encoded.
This is a strong statement about the structure of information in the universe. As far as we know, it is accurate, and it is the reason deep learning works. It holds not only for MNIST digits but also for tree morphologies, human faces, the sound of the human voice, and even natural language.
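The claim that a manifold is locally close to a Euclidean space can be checked numerically on the simplest example: the unit circle, a one-dimensional manifold in the plane. In the sketch below (the parameterization by angle, the base point, and the step sizes are choices made for illustration), a small step along the circle is compared with the same step along the tangent line. The gap shrinks roughly as h^2/2, much faster than the step h itself, so near any point the curve is well approximated by a line.

```python
import numpy as np

def point_on_circle(t):
    """Map the single coordinate t (an angle) to a 2-D point on the unit circle."""
    return np.array([np.cos(t), np.sin(t)])

def tangent_vector(t):
    """Derivative of point_on_circle with respect to t."""
    return np.array([-np.sin(t), np.cos(t)])

t0 = 0.7  # arbitrary base point on the curve
steps = [0.1, 0.01, 0.001]
gaps = []
for h in steps:
    on_curve = point_on_circle(t0 + h)
    on_tangent = point_on_circle(t0) + h * tangent_vector(t0)
    gap = np.linalg.norm(on_curve - on_tangent)
    gaps.append(gap)
    print(f"step {h}: gap {gap:.2e}, h^2/2 = {h**2 / 2:.2e}")
```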
Interpolation as the Source of Generalization
The manifold hypothesis implies that machine learning models only have to fit relatively simple, low-dimensional, highly structured subspaces within their input space (latent manifolds). Within one of these manifolds, it is always possible to interpolate between two inputs, that is, to morph one into the other via a continuous path along which all points fall on the manifold.
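Interpolating along a manifold is not the same as averaging coordinates in the parent space. The sketch below again uses the unit circle as a stand-in for a data manifold (an assumption for illustration). Interpolating the manifold coordinate keeps every intermediate point on the circle, while blending the raw 2-D coordinates cuts across the chord and leaves the manifold. This is analogous to the difference between morphing one digit into another and merely cross-fading two images pixel by pixel.

```python
import numpy as np

def point_on_circle(t):
    return np.array([np.cos(t), np.sin(t)])

t_a, t_b = 0.0, np.pi / 2          # manifold coordinates of samples A and B
a, b = point_on_circle(t_a), point_on_circle(t_b)

for alpha in np.linspace(0.0, 1.0, 5):
    # Path on the manifold: interpolate the angle, then map back to the plane
    on_manifold = point_on_circle((1 - alpha) * t_a + alpha * t_b)
    # Naive path in the parent space: blend the 2-D coordinates directly
    in_parent_space = (1 - alpha) * a + alpha * b
    print(f"alpha={alpha:.2f}  "
          f"|manifold point|={np.linalg.norm(on_manifold):.3f}  "
          f"|blended point|={np.linalg.norm(in_parent_space):.3f}")
```

Every point on the manifold path has norm 1, so it stays on the circle, whereas the blended midpoint has norm of about 0.707 and lies off the manifold.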
The ability to interpolate between samples is the key to understanding generalization in deep learning.
If you work with data points that can be interpolated, you can start making sense of points you have never seen before by relating them to other points that lie close to them on the manifold. In other words, you can make sense of the whole space using only a sample of it, using interpolation to fill in the blanks.
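Filling in the blanks can be sketched on the same toy manifold. Each point on the unit circle is described by its angle, and the target is a smooth function of position on the manifold (the sine of the angle, an arbitrary choice for illustration). Given 20 evenly spaced training samples, the prediction for an unseen angle is a linear interpolation between its two nearest training neighbors, computed with `np.interp` and its `period` argument so that it wraps around the circle. Note the simplification: here the manifold coordinate is given, whereas a deep learning model has to learn an approximation of it.

```python
import numpy as np

# Training samples: 20 evenly spaced angles on the circle and their targets
train_angles = np.linspace(0.0, 2 * np.pi, 20, endpoint=False)
train_targets = np.sin(train_angles)

# An angle that does not appear in the training set
new_angle = 1.0
prediction = np.interp(new_angle, train_angles, train_targets,
                       period=2 * np.pi)  # wrap around the circle
error = abs(prediction - np.sin(new_angle))
print(f"prediction={prediction:.4f}, true={np.sin(new_angle):.4f}, error={error:.4f}")
```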
However, while deep learning achieves generalization via interpolation on a learned approximation of the data manifold, it would be a mistake to assume that interpolation is all there is to generalization. It is only the tip of the iceberg. Interpolation can only help you make sense of things that are very close to what you have seen before: it enables local generalization. Remarkably, humans deal with extreme novelty all the time. Every day is different from any day you have experienced before, and different from any day experienced by anyone in human history. You can spend a week in New York, then a week in Shanghai, then a week in Bangalore, without first practicing for each city.
Humans are capable of extreme generalization, which is enabled by cognitive mechanisms other than interpolation: abstraction, symbolic models of the world, reasoning, logic, common sense, and innate priors about the world. We generally call these reason, as opposed to intuition and pattern recognition. The latter are largely interpolative in nature; the former are not. Both are essential to intelligence.
Why Deep Learning Works
A deep learning model is essentially a very high-dimensional curve: a smooth, continuous curve, since it needs to be differentiable. That curve is fitted to the data points gradually and smoothly via gradient descent. By its very nature, deep learning is about taking a big, complex curve (a manifold) and incrementally adjusting its parameters until it fits some training data points.
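The incremental fitting described here can be sketched with a toy curve: a cubic polynomial whose four coefficients are adjusted by plain gradient descent on the mean squared error. This illustrates the principle only; the target function, learning rate, and step count are arbitrary choices, and a real network is a far more complex curve trained with backpropagation.

```python
import numpy as np

x = np.linspace(-1.0, 1.0, 20)
y = np.sin(np.pi * x)                                # training points to fit

features = np.stack([x**0, x, x**2, x**3], axis=1)   # cubic curve: y ~ features @ w
w = np.zeros(4)                                      # start from a flat curve
learning_rate = 0.1
losses = []

for step in range(2000):
    residual = features @ w - y
    losses.append(np.mean(residual**2))              # mean squared error
    grad = 2 * features.T @ residual / len(x)        # gradient of the MSE w.r.t. w
    w -= learning_rate * grad                        # small step downhill

print(f"loss at step 0: {losses[0]:.4f}, after {len(losses)} steps: {losses[-1]:.4f}")
```

Each step moves the curve only slightly, so the fit improves gradually rather than jumping straight to the training points.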
The curve has enough parameters that it could fit anything. Indeed, if you let the model train for long enough, it will end up purely memorizing its training data and will not generalize at all.
However, the data being fitted is not made of isolated points sparsely distributed across the underlying space. It forms a highly structured, low-dimensional manifold within the input space; that is the manifold hypothesis. And because the model curve is fitted to this data gradually and smoothly as gradient descent progresses, there will be an intermediate point during training at which the model roughly approximates the natural manifold of the data.
At that point, moving along the curve learned by the model comes close to moving along the actual latent manifold of the data. The model can therefore make sense of never-before-seen inputs via interpolation between training inputs.
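Memorization with enough capacity can be shown on a tiny scale, echoing the shuffled-label experiment above. A polynomial with as many coefficients as there are training points can pass exactly through all of them, even when the targets are pure noise. For brevity this sketch uses `np.polyfit` (least squares, not gradient descent); the sizes and random seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
x_train = np.linspace(-1.0, 1.0, 10)
y_train = rng.normal(size=10)        # random targets: nothing to generalize

# Degree 9 means 10 coefficients, exactly enough to pass through all 10 points
coefficients = np.polyfit(x_train, y_train, deg=9)
train_error = np.max(np.abs(np.polyval(coefficients, x_train) - y_train))
print(f"max training error: {train_error:.2e}")  # essentially zero: memorized
```

A perfect training fit here says nothing about unseen inputs, because there was no structure in the targets to learn.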
Critical Importance of Training Data
While deep learning is well suited to manifold learning, the power to generalize is more a consequence of the natural structure of the data than of any property of the model.
A model can generalize only if the data forms a manifold where points can be interpolated. The more informative and the less noisy the features, the better the model can generalize, since the input space is then simpler and better structured. This is why data curation and feature engineering are essential to generalization.
Furthermore, because deep learning is curve fitting, a model needs to be trained on a dense sampling of its input space to perform well. Dense sampling means that the training data should densely cover the entire input data manifold, especially near decision boundaries. With sufficiently dense sampling, it becomes possible to make sense of new inputs by interpolating between past training inputs, without having to use common sense, abstract reasoning, or external knowledge about the world, none of which machine learning models generally have access to.
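The effect of sampling density near a decision boundary can be sketched with a one-nearest-neighbor classifier, swapped in here as a stand-in for a trained model that predicts purely by interpolating from the closest training input. The setup is an assumption for illustration: points on the unit circle are labeled 1 when their angle is below π, and training sets of 4 and 64 evenly spaced points are compared on the same 1,000 test points.

```python
import numpy as np

def circle_points(angles):
    return np.stack([np.cos(angles), np.sin(angles)], axis=1)

def label(angles):
    return (angles < np.pi).astype(int)   # decision boundaries at angles 0 and pi

def nearest_neighbor_accuracy(n_train, test_angles):
    train_angles = np.linspace(0.0, 2 * np.pi, n_train, endpoint=False)
    train_x, train_y = circle_points(train_angles), label(train_angles)
    test_x = circle_points(test_angles)
    # Distance from every test point to every training point in the 2-D plane
    distances = np.linalg.norm(test_x[:, None, :] - train_x[None, :, :], axis=2)
    predictions = train_y[np.argmin(distances, axis=1)]
    return np.mean(predictions == label(test_angles))

test_angles = np.linspace(0.0, 2 * np.pi, 1000, endpoint=False) + 0.001
for n_train in [4, 64]:
    print(n_train, nearest_neighbor_accuracy(n_train, test_angles))
```

With only 4 training points, every test point between the last training sample and a boundary is misclassified, about a quarter of the circle (accuracy roughly 0.75). With 64 points, errors shrink to thin slivers next to the boundaries (accuracy roughly 0.98).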