Training Deep Networks
Why do we need batch normalization layers? Let us review some practical challenges that arise when training neural networks.
First, the way data are preprocessed often dramatically influences the final result. Recall the example of using a multilayer perceptron to predict house prices. When working with real data, our first step was to standardize input features to have zero mean and unit variance. Intuitively, this standardization works well with our optimizers because it places the parameters on a comparable scale.
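As a quick refresher on that first step, input standardization amounts to a simple per-feature transformation. The sketch below assumes a hypothetical tensor features holding raw inputs, one row per example.

import torch

# Hypothetical raw inputs: four examples whose three features live on very different scales
features = torch.tensor([[100.0, 0.001, 5.0],
                         [200.0, 0.002, 6.0],
                         [150.0, 0.004, 4.0],
                         [250.0, 0.003, 7.0]])
# Standardize each feature to zero mean and unit variance
standardized = (features - features.mean(dim=0)) / features.std(dim=0)
print(standardized.mean(dim=0))  # close to zero for every feature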
Second, for a typical multilayer perceptron or convolutional neural network, the variables in intermediate layers (e.g., affine transformation outputs in an MLP) may exhibit a wide range of magnitudes: across layers from input to output, across units within the same layer, or over time as the model parameters update unpredictably during training. The inventors of batch normalization informally hypothesized that such shifts in the distribution of these variables could hamper network convergence. Intuitively, we might suspect that if the values of one layer are 100 times larger than those of another, this may require compensating adjustments in the learning rate.
Third, deeper networks are complex and prone to overfitting. This means that regularization becomes more important.
Batch normalization is applied to individual layers (optionally to all layers). The principle works as follows: In each training iteration, we first normalize the inputs by subtracting their mean and dividing by their standard deviation, both computed over the current minibatch. Next, we apply a scaling coefficient and a shifting offset. It is precisely because of this standardization based on batch statistics that the method is named batch normalization.
Note that if we attempt to apply batch normalization with a minibatch of size 1, we would learn nothing. After subtracting the mean, every hidden unit would become zero. Therefore, batch normalization is effective and stable only when using a sufficiently large minibatch. When applying batch normalization, the choice of batch size may be even more critical than without it.
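To see this degenerate case concretely, here is a minimal sketch (the small constant added for numerical stability is omitted): with a single example, the minibatch mean equals the example itself, so every centered value is exactly zero.

import torch

x = torch.tensor([[2.0, -1.0, 3.5]])  # a "minibatch" containing a single example
centered = x - x.mean(dim=0)          # the minibatch mean is x itself
print(centered)                       # tensor([[0., 0., 0.]])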
Formally, let \(\mathbf{x} \in \mathcal{B}\) denote an input from a minibatch \(\mathcal{B}\). Batch normalization \(\mathrm{BN}\) transforms \(\mathbf{x}\) according to the following expression:

\[\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta} \qquad (7.5.1)\]

Here \(\boldsymbol{\gamma}\) is the elementwise scale coefficient and \(\boldsymbol{\beta}\) the shift offset; both have the same shape as \(\mathbf{x}\) and are learned jointly with the other model parameters.
Since we do not want the magnitudes of intermediate layers to drift wildly during training, batch normalization actively centers each layer's values and rescales them to a common mean and scale (via \(\hat{\boldsymbol{\mu}}_\mathcal{B}\) and \({\hat{\boldsymbol{\sigma}}_\mathcal{B}}\)).
Formally, we compute \(\hat{\boldsymbol{\mu}}_\mathcal{B}\) and \({\hat{\boldsymbol{\sigma}}_\mathcal{B}}\) in (7.5.1) as follows:

\[\hat{\boldsymbol{\mu}}_\mathcal{B} = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x}, \qquad \hat{\boldsymbol{\sigma}}_\mathcal{B}^2 = \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B})^2 + \epsilon.\]

Note that we add a small constant \(\epsilon > 0\) to the variance estimate so that we never attempt division by zero, even when the empirical variance happens to vanish.
Estimating the mean and variance from minibatches injects noise into each layer, and, perhaps surprisingly, this noise turns out to be beneficial. This is a recurring theme in deep learning. For reasons that are not yet well understood theoretically, various sources of noise in optimization often lead to faster training and less overfitting: the variation appears to act as a form of regularization. In preliminary studies, Teye et al. (2018) and Luo et al. (2018) related the properties of batch normalization to Bayesian priors and penalties, respectively. These theories help explain the puzzle of why batch normalization works best for moderate minibatch sizes in the 50–100 range.
Furthermore, batch normalization layers behave differently in "training mode" (normalizing by minibatch statistics) and in "prediction mode" (normalizing by dataset statistics). During training we cannot practically use the entire dataset to estimate the mean and variance at every step, so we normalize each minibatch by its own statistics as training proceeds. In prediction mode, by contrast, we can use the mean and variance estimated over the entire dataset to compute the normalization accurately.
Now, let us examine how batch normalization works in practice.
Batch Normalization Layers
Recall that a key difference between batch normalization and other layers is that batch normalization operates on full minibatches. Consequently, we cannot ignore the batch size as we did when introducing other layers. We discuss the two cases below: fully connected layers and convolutional layers. Their batch normalization implementations differ slightly.
Fully Connected Layers
Typically, we place the batch normalization layer between the affine transformation and the activation function in a fully connected layer. Denoting the input to the fully connected layer by \(\mathbf{x}\), the weight parameter by \(\mathbf{W}\), the bias parameter by \(\mathbf{b}\), the activation function by \(\phi\), and the batch normalization operator by \(\mathrm{BN}\), we compute the output \(\mathbf{h}\) of a fully connected layer with batch normalization as follows:

\[\mathbf{h} = \phi(\mathrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b})).\]

Recall that the mean and variance are computed on the same minibatch to which the transformation is applied.
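In code, this ordering simply means placing the batch normalization layer between the affine layer and the activation. The snippet below only illustrates the placement, using the framework's built-in nn.BatchNorm1d with arbitrary layer sizes.

import torch
from torch import nn

# h = phi(BN(Wx + b)): affine transformation, then batch normalization, then activation
fc_block = nn.Sequential(nn.Linear(20, 10), nn.BatchNorm1d(10), nn.Sigmoid())
X = torch.randn(8, 20)    # a minibatch of 8 examples with 20 features
print(fc_block(X).shape)  # torch.Size([8, 10])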
Convolutional Layers
Similarly, for convolutional layers, we can apply batch normalization after the convolution and before the nonlinear activation function. When a convolution has multiple output channels, we must perform batch normalization on each output channel, each with its own scale and shift parameters, which are both scalars. Assume our minibatch contains \(m\) samples, and for each channel the output of the convolution has height \(p\) and width \(q\). Then for a convolutional layer, we simultaneously perform batch normalization on the \(m \cdot p \cdot q\) elements per output channel. Therefore, when computing the mean and variance, we aggregate values over all spatial locations and then apply the same mean and variance within a given channel to normalize the values at each spatial position.
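To make the bookkeeping concrete, the short sketch below (with arbitrary shapes) mirrors this recipe: one mean and one variance per channel, each aggregated over the batch and spatial dimensions, then broadcast back over every position of that channel.

import torch

m, c, p, q = 2, 3, 4, 4                 # minibatch size, channels, height, width
X = torch.randn(m, c, p, q)
# One statistic per channel, each computed over m * p * q elements
per_channel_mean = X.mean(dim=(0, 2, 3), keepdim=True)
per_channel_var = ((X - per_channel_mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
X_hat = (X - per_channel_mean) / torch.sqrt(per_channel_var + 1e-5)
print(per_channel_mean.shape)           # torch.Size([1, 3, 1, 1])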
Batch Normalization During Prediction
As mentioned earlier, batch normalization typically behaves differently in training mode and prediction mode. First, when using a trained model for prediction, we no longer need the noise from sample means and the sample variances estimated from each minibatch. Second, we may need to use our model to predict one sample at a time. A common approach is to estimate the sample mean and variance of the entire training dataset using a moving average and use these stable estimates during prediction to obtain deterministic outputs. Thus, like dropout, batch normalization layers compute different results in training mode and prediction mode.
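The sketch below illustrates the two modes using the framework's built-in layer: switching between .train() and .eval() toggles between minibatch statistics and the accumulated moving averages, and in prediction mode the output is deterministic.

import torch
from torch import nn

bn = nn.BatchNorm1d(3)
X = torch.randn(8, 3)
bn.train()      # training mode: normalize with minibatch statistics
_ = bn(X)       # side effect: the moving averages are updated
bn.eval()       # prediction mode: normalize with the moving averages instead
print(torch.allclose(bn(X), bn(X)))  # True: repeated predictions are deterministic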
Implementation from Scratch
Below, we implement a batch normalization layer with tensors from scratch.
import torch
from torch import nn
from d2l import torch as d2l
def batch_norm_forward(X, scale, shift, running_mean, running_var, eps, momentum):
    # Determine whether we are in training or prediction mode
    if not torch.is_grad_enabled():
        # In prediction mode, use the running averages
        X_norm = (X - running_mean) / torch.sqrt(running_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected case: mean and variance over features
            batch_mean = X.mean(dim=0)
            batch_var = ((X - batch_mean) ** 2).mean(dim=0)
        else:
            # 2D convolutional case: mean and variance per channel (axis=1),
            # keeping dimensions for broadcasting
            batch_mean = X.mean(dim=(0, 2, 3), keepdim=True)
            batch_var = ((X - batch_mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Normalize using current batch statistics (training mode)
        X_norm = (X - batch_mean) / torch.sqrt(batch_var + eps)
        # Update running statistics
        running_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
        running_var = momentum * running_var + (1.0 - momentum) * batch_var
    Y = scale * X_norm + shift  # Scale and shift
    return Y, running_mean.data, running_var.data
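As a quick sanity check of this function (using a random minibatch, with the scale fixed to one and the shift to zero), each output feature should come out with roughly zero mean and unit variance in training mode:

X = torch.randn(8, 4)
scale, shift = torch.ones(1, 4), torch.zeros(1, 4)
Y, mean_est, var_est = batch_norm_forward(
    X, scale, shift, torch.zeros(1, 4), torch.ones(1, 4), eps=1e-5, momentum=0.9)
print(Y.mean(dim=0))         # close to zero for every feature
print((Y ** 2).mean(dim=0))  # close to one for every feature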
We can now create a proper BatchNorm layer. This layer will maintain the appropriate parameters: the learnable scale and shift, which will be updated during training. Additionally, our layer will keep moving averages of the mean and variance for use during model prediction.
Setting aside algorithmic details, notice the fundamental design pattern for implementing a layer. Typically, we define the mathematical logic in a separate function, such as batch_norm_forward. We then integrate this functionality into a custom layer whose code mainly handles bookkeeping: moving data to the training device (e.g., a GPU), allocating and initializing any required variables, and keeping track of the moving averages (here of the mean and variance). For convenience, we do not worry about automatically inferring the input shape, so we need to specify the number of features throughout. The batch normalization APIs provided by deep learning frameworks handle these details for us, as we will demonstrate later.
class BatchNorm(nn.Module):
    # num_features: number of outputs for a fully connected layer or number of
    # output channels for a convolutional layer
    # num_dims: 2 for a fully connected layer, 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Learnable scale and shift parameters, initialized to 1 and 0
        self.scale = nn.Parameter(torch.ones(shape))
        self.shift = nn.Parameter(torch.zeros(shape))
        # Variables that are not model parameters, initialized to 0 and 1
        self.running_mean = torch.zeros(shape)
        self.running_var = torch.ones(shape)

    def forward(self, X):
        # Move the running statistics to the same device as X if necessary
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_var = self.running_var.to(X.device)
        # Save the updated running_mean and running_var
        Y, self.running_mean, self.running_var = batch_norm_forward(
            X, self.scale, self.shift, self.running_mean, self.running_var,
            eps=1e-5, momentum=0.9)
        return Y
LeNet with Batch Normalization
To understand how to apply BatchNorm, we apply it to the LeNet model. Recall that batch normalization is applied after the convolutional or fully connected layers and before the corresponding activation function.
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))
As before, we will train the network on the Fashion-MNIST dataset. The code is almost identical to when we first trained LeNet, with the main difference being a significantly larger learning rate.
lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.273, train acc 0.898, test acc 0.801
20734.6 examples/sec on cuda:0

Let us examine the scale and shift parameters learned by the first batch normalization layer.
net[1].scale.reshape((-1,)), net[1].shift.reshape((-1,))
(tensor([3.4891, 2.9934, 1.6090, 3.6851, 3.2831, 2.1519], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>),
tensor([-3.6648, 2.0706, -1.5978, -0.8258, -2.1900, -0.4488], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>))
Concise Implementation
Besides using the BatchNorm we just defined, we can directly use the BatchNorm provided by the deep learning framework. The code looks almost identical to the one above.
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))
Now, we train the model with the same hyperparameters. Note that the high-level API variant runs significantly faster because its code is compiled into C++ or CUDA, whereas our custom implementation is written in Python.
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.262, train acc 0.903, test acc 0.880
40532.7 examples/sec on cuda:0

Controversy
Intuitively, batch normalization is believed to make the optimization landscape smoother. However, we must be careful to distinguish speculative intuitions from genuine explanations of the phenomena we observe. Recall that we do not even fully understand why simpler deep neural networks (MLPs and conventional CNNs) generalize so well in the first place. Even with dropout and weight decay, they remain so flexible that conventional learning-theoretic generalization guarantees cannot explain their ability to generalize to unseen data.
In the original paper proposing batch normalization, the authors, in addition to introducing its application, offered an explanation for why it works: by reducing internal covariate shift. Presumably, the internal covariate shift the authors had in mind resembles the speculative intuition described above: the notion that the distribution of variable values changes over the course of training. However, this explanation has two problems: 1) the drift in question is very different from the rigorously defined notion of covariate shift, making the term a misnomer; 2) the explanation offers only an under-specified intuition and leaves open the real question: why exactly is this technique so effective?
As batch normalization grew in popularity, the internal covariate shift explanation repeatedly surfaced in debates in the technical literature and in broader discussions of how machine learning research should be presented. In a memorable speech given while accepting the Test of Time Award at NeurIPS 2017, Ali Rahimi used internal covariate shift as a focal point in an argument likening modern deep learning practice to alchemy. The example was later revisited in detail in a position paper outlining troubling trends in machine learning (Lipton and Steinhardt, 2018). Other authors have since proposed alternative explanations for the success of batch normalization, some arguing that its behavior is in certain respects the opposite of what the original paper claimed (Santurkar et al., 2018).
Nevertheless, internal covariate shift is no more deserving of criticism than thousands of similarly vague claims in the machine learning literature. It most likely resonated as a focus of those debates thanks to its widespread recognition by the target audience. Batch normalization has proven to be an indispensable method. It is used in nearly every image classifier and has garnered tens of thousands of citations in academia.
Summary
- During model training, batch normalization continuously adjusts the intermediate outputs of the neural network using the mean and standard deviation of minibatches, making the output values of every layer more stable.
- Batch normalization is applied slightly differently in fully connected layers and convolutional layers.
- Like dropout, batch normalization layers compute differently in training mode and prediction mode.
- Batch normalization has many beneficial side effects, primarily regularization. On the other hand, the original motivation of "reducing internal covariate shift" does not appear to be a valid explanation.