Batch Normalization

Batch normalization (BN) standardizes the inputs to each layer using statistics computed over the mini-batch during training, then applies a learnable affine transform (scale γ and shift β) so the network keeps its representational power and can even recover the identity mapping if that is optimal. It tends to stabilize gradients, allow higher learning rates, and act as a mild regularizer, because each example’s normalization depends on the other examples in the batch.

What Batch Norm Does

For a tensor of activations, BN computes mean and variance across the normalization axes (for fully connected layers, often over the batch dimension; for conv layers, over batch and spatial dims per channel). It then transforms x̂ = (x − μ) / √(σ² + ε) and outputs y = γ x̂ + β. The small ε avoids division by zero.
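
As a concrete check, the transform can be reproduced by hand in PyTorch; the sketch below (illustrative only, not how you would implement BN in practice) normalizes a conv-style tensor per channel over batch and spatial dims and compares against nn.BatchNorm2d in training mode:

import torch
import torch.nn as nn

x = torch.randn(8, 4, 16, 16)                 # (N, C, H, W)
bn = nn.BatchNorm2d(4)                        # γ initialized to 1, β to 0
bn.train()

# Per-channel μ and σ² over batch and spatial dims (biased variance, as BN uses for normalization)
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
x_hat = (x - mu) / torch.sqrt(var + bn.eps)
y_manual = bn.weight.view(1, -1, 1, 1) * x_hat + bn.bias.view(1, -1, 1, 1)

print(torch.allclose(bn(x), y_manual, atol=1e-5))   # True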

The original paper motivated BN as reducing internal covariate shift—the change in input distribution to layers as parameters update. Whether that story is the full explanation remains debated; empirically BN often smooths the loss landscape and improves optimization in CNNs.

Training vs Inference

During training, μ and σ² come from the current batch. During inference, batch statistics would be unreliable (and meaningless for a batch of one), so frameworks maintain exponential moving averages of the mean and variance during training and use those frozen values at test time.
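
A rough sketch of that bookkeeping, assuming PyTorch's convention (momentum defaults to 0.1 and the running variance uses the unbiased batch estimate):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64)                  # running_mean starts at 0, running_var at 1
x = torch.randn(32, 64, 8, 8)

bn.train()
_ = bn(x)                                # a train-mode forward pass updates the buffers

# Roughly what happened inside:
# running_mean <- (1 - momentum) * running_mean + momentum * batch_mean
# running_var  <- (1 - momentum) * running_var  + momentum * batch_var (unbiased)
print(bn.running_mean[:4], bn.running_var[:4])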

In PyTorch, call model.eval() before validation or deployment so BatchNorm and Dropout switch behavior. Forgetting this is a classic source of “validation accuracy much worse than training” even when the model is fine.
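
A minimal sketch of the toggle (the model and shapes here are arbitrary):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

model.train()                            # BN uses batch stats and updates running buffers
# ... training loop ...

model.eval()                             # BN switches to the frozen running mean/variance
with torch.no_grad():
    out = model(torch.randn(1, 3, 32, 32))   # safe even for a single example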

Small batches. With very small batch sizes, batch statistics are high-variance; consider GroupNorm or LayerNorm in those regimes, or accumulate statistics across devices (e.g., SyncBatchNorm) so normalization sees a larger effective batch.
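
For example, a drop-in conv block whose normalization does not depend on batch statistics at all (the group count of 8 is an arbitrary illustrative choice):

import torch.nn as nn

# GroupNorm normalizes within each example over groups of channels,
# so its statistics are independent of batch size.
small_batch_block = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.GroupNorm(8, 64),                 # 8 groups of 8 channels each
    nn.ReLU(),
)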

Where to Place BN

Common pattern in CNNs: Conv → BatchNorm → ReLU. For linear layers, Linear → BatchNorm → activation. Some architectures use BN before activation; others after—consistency within a model matters more than dogma, but follow the reference implementation when reproducing papers.

BN interacts with weight decay: a common practice is to exclude BN’s γ, β (and biases) from L2 regularization; in PyTorch, optimizer parameter groups (for example with AdamW) let you exclude them from decay if desired.
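
One common way to do this is with optimizer parameter groups; the split below (by parameter dimensionality, since biases and BN γ/β are 1-D) is a heuristic sketch rather than a universal rule:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10),
)

decay, no_decay = [], []
for p in model.parameters():
    # 1-D parameters are biases and BN γ/β; keep them out of weight decay
    (no_decay if p.ndim == 1 else decay).append(p)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 1e-2},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)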

PyTorch: BatchNorm1d / BatchNorm2d

MLP and conv blocks
import torch.nn as nn

mlp_block = nn.Sequential(
    nn.Linear(512, 256),
    nn.BatchNorm1d(256),   # per-feature stats over the batch; expects (N, C) or (N, C, L)
    nn.ReLU(),
)

conv_block = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),    # per-channel stats over batch and spatial dims; expects (N, C, H, W)
    nn.ReLU(),
)
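
A quick shape check reusing the two blocks above (illustrative only):

import torch

x_fc  = torch.randn(32, 512)             # (N, features) for BatchNorm1d
x_img = torch.randn(32, 3, 32, 32)       # (N, C, H, W) for BatchNorm2d

print(mlp_block(x_fc).shape)             # torch.Size([32, 256])
print(conv_block(x_img).shape)           # torch.Size([32, 64, 32, 32])

Note that in train mode BatchNorm needs more than one value per channel, so a single-example batch through mlp_block would raise an error.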

Summary

  • BN normalizes activations per batch (per channel in conv) then applies learnable γ, β.
  • Training uses batch stats; inference uses running averages—toggle with train()/eval().
  • Often improves optimization for CNNs; small-batch settings may prefer GroupNorm/LayerNorm.
  • Placement and weight-decay handling should match your baseline architecture.

Next: when the model fits too well—overfitting—and how to recognize it on learning curves.