Every neural network begins its life as a collection of randomly initialized weights. The specific values those weights start with might seem like a trivial detail compared to architecture design or training data, but poor initialization can make the difference between a model that learns in hours and one that never learns at all. Weight initialization is one of those foundational topics that every deep learning practitioner needs to understand.
What Goes Wrong with Naive Initialization
To appreciate why initialization matters, consider what happens when you get it wrong. There are two common failure modes:
All Zeros
If you initialize all weights to zero, every neuron in a layer computes the same output. During backpropagation, they all receive the same gradient and update identically. This symmetry problem means the network effectively has only one neuron per layer, regardless of how wide the layers are. The model cannot learn diverse features.
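The symmetry problem is easy to see numerically. The sketch below (plain NumPy, a hypothetical two-layer regression net invented for illustration) runs one manual forward/backward pass with all-zero weights; every gradient comes out exactly zero, so no update can break the symmetry:

```python
import numpy as np

rng = np.random.default_rng(0)

# Tiny 2-layer net: x -> W1 -> tanh -> W2 -> y, all weights zero.
x = rng.normal(size=(8, 3))          # batch of 8 inputs, 3 features
t = rng.normal(size=(8, 1))          # regression targets
W1 = np.zeros((3, 4))                # zero-initialized first layer
W2 = np.zeros((4, 1))                # zero-initialized second layer

# Forward pass
h = np.tanh(x @ W1)                  # hidden activations: all zero
y = h @ W2                           # outputs: all zero

# Backward pass (mean squared error)
dy = 2 * (y - t) / len(x)
dW2 = h.T @ dy                       # zero, because h is zero
dh = dy @ W2.T                       # zero, because W2 is zero
dW1 = x.T @ (dh * (1 - h**2))        # zero as well

# Every weight receives the same (zero) gradient; training is stuck.
print(np.abs(dW1).max(), np.abs(dW2).max())   # both 0.0
```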
Random with Wrong Scale
If weights are drawn from a distribution that is too wide (too large a variance), activations grow exponentially as they pass through layers. This leads to exploding gradients: gradients become astronomically large, weight updates overshoot wildly, and training becomes unstable.
Conversely, if the distribution is too narrow (too small variance), activations shrink toward zero through successive layers. This causes vanishing gradients, where gradients become negligibly small, and the early layers barely learn at all.
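Both failure modes can be reproduced in a few lines. This NumPy sketch (linear layers, chosen so the scaling effect is not masked by a saturating activation) pushes a unit-variance signal through 50 layers at three weight scales:

```python
import numpy as np

rng = np.random.default_rng(0)
depth, width = 50, 256
x = rng.normal(size=(width,))        # unit-variance input signal

def final_std(weight_std):
    """Propagate the signal through `depth` linear layers and return its final std."""
    h = x
    for _ in range(depth):
        W = rng.normal(scale=weight_std, size=(width, width))
        h = W @ h
    return h.std()

print(final_std(0.01))                  # too narrow: the signal vanishes toward zero
print(final_std(0.10))                  # too wide: the signal explodes
print(final_std(1 / np.sqrt(width)))    # scale-preserving: stays on the order of 1
```

Each layer multiplies the variance by roughly width * weight_std^2, so only weight_std = 1/sqrt(width) leaves the scale unchanged.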
The goal of proper initialization is to maintain roughly the same variance of activations and gradients across all layers, preventing both explosion and vanishing.
Key Takeaway
Poor initialization causes either vanishing or exploding activations through deep networks. Proper initialization keeps the signal flowing at a stable magnitude through all layers.
Xavier (Glorot) Initialization
In 2010, Xavier Glorot and Yoshua Bengio derived the conditions needed to maintain stable activation and gradient variance through a network with symmetric activations like tanh or sigmoid. Their Xavier initialization draws weights from a distribution with variance:
Var(W) = 2 / (n_in + n_out)
where n_in is the number of input connections and n_out is the number of output connections for each neuron. This can be implemented as either a uniform or normal distribution:
- Uniform: W ~ Uniform(-sqrt(6/(n_in+n_out)), sqrt(6/(n_in+n_out)))
- Normal: W ~ Normal(0, σ) with standard deviation σ = sqrt(2/(n_in+n_out))
The intuition is elegant: by accounting for both the fan-in and fan-out of each layer, Xavier initialization ensures that the variance of activations remains approximately constant during the forward pass, and the variance of gradients remains approximately constant during the backward pass.
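This variance-preservation claim can be checked directly. A minimal NumPy sketch using the uniform variant, with n_in = n_out = 512 chosen for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
n_in = n_out = 512

# Xavier-uniform limit a = sqrt(6/(n_in+n_out)) gives Var(W) = a^2/3 = 2/(n_in+n_out)
a = np.sqrt(6.0 / (n_in + n_out))
W = rng.uniform(-a, a, size=(n_out, n_in))
print(W.var())                # close to 2/1024 ~= 0.00195

# A unit-variance signal keeps roughly unit variance through many such layers
h = rng.normal(size=(n_in,))
for _ in range(20):
    W = rng.uniform(-a, a, size=(n_out, n_in))
    h = W @ h
print(h.var())                # stays on the order of 1
```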
He (Kaiming) Initialization
Xavier initialization was designed for symmetric activation functions. When ReLU activations became dominant, Kaiming He et al. showed in 2015 that Xavier initialization produced suboptimal results because ReLU zeros out half the activations, effectively halving the variance at each layer.
He initialization compensates by doubling the variance:
Var(W) = 2 / n_in
Note that He initialization only considers the fan-in, not the fan-out. The factor of 2 in the numerator compensates for ReLU's effect of zeroing approximately half the neurons.
In PyTorch, linear and convolutional layers use a Kaiming-uniform variant by default, and He initialization can be applied explicitly:
import torch

w = torch.empty(256, 128)   # e.g. a linear layer's weight (out_features, in_features)

# He normal initialization
torch.nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')

# He uniform initialization
torch.nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
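The factor of 2 can also be sanity-checked numerically. In this NumPy sketch, a unit-scale signal passes through 30 ReLU layers with He-normal weights; despite ReLU zeroing roughly half the units at each layer, the signal's magnitude stays stable:

```python
import numpy as np

rng = np.random.default_rng(0)
n_in = 1024

h = rng.normal(size=(n_in,))         # unit-scale input signal
for _ in range(30):
    W = rng.normal(scale=np.sqrt(2.0 / n_in), size=(n_in, n_in))
    h = np.maximum(W @ h, 0.0)       # ReLU zeros roughly half the units

rms = np.sqrt((h ** 2).mean())
print(rms)   # remains near the original scale thanks to the factor of 2
```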
Practical Guidelines
Here is a quick reference for choosing initialization methods:
- ReLU, Leaky ReLU, ELU: Use He (Kaiming) initialization
- Tanh, Sigmoid: Use Xavier (Glorot) initialization
- SELU: Use LeCun initialization (Var = 1/n_in), specifically designed for self-normalizing networks
- Batch Normalization: Largely reduces sensitivity to initialization, but He initialization remains a good default
- Transformers: Typically use custom scaled initialization, often with weights scaled by 1/sqrt(d_model) or 1/sqrt(2*n_layers)
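These rules of thumb can be encoded in a small PyTorch helper. This is only a sketch: init_linear is a name invented here, not a library function, and it covers just the first three rows of the list above.

```python
import torch
import torch.nn as nn

def init_linear(layer: nn.Linear, activation: str) -> nn.Linear:
    """Initialize a linear layer's weights based on the activation that follows it."""
    if activation in ("relu", "leaky_relu"):
        # He (Kaiming) initialization for ReLU-family activations
        nn.init.kaiming_normal_(layer.weight, nonlinearity=activation)
    elif activation in ("tanh", "sigmoid"):
        # Xavier (Glorot) initialization, with the gain recommended for the activation
        nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain(activation))
    elif activation == "selu":
        # LeCun normal: Var(W) = 1 / fan_in
        fan_in = layer.weight.shape[1]
        nn.init.normal_(layer.weight, std=(1.0 / fan_in) ** 0.5)
    else:
        raise ValueError(f"no rule for activation {activation!r}")
    nn.init.zeros_(layer.bias)
    return layer

layer = init_linear(nn.Linear(128, 64), "relu")
```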
The Role of Normalization
Batch normalization, layer normalization, and other normalization techniques have reduced (but not eliminated) the importance of careful initialization. These methods normalize activations at each layer, preventing the variance from drifting too far. However, even with normalization, poor initialization can still slow convergence or cause instabilities in the early stages of training.
Key Takeaway
Use He initialization for ReLU-based networks, Xavier for tanh/sigmoid, and scaled initialization for transformers. Normalization layers help but do not eliminate the need for thoughtful initialization.
Initialization in Modern Architectures
Large language models and modern transformers use specialized initialization strategies. GPT-style models typically initialize residual layers with a scaling factor of 1/sqrt(2*n_layers) to prevent the residual connections from accumulating too much variance through the depth of the network.
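A sketch of that recipe in PyTorch (the 0.02 base standard deviation follows the GPT-2 convention; n_layers and d_model here are illustrative values, not a specific model's configuration):

```python
import math
import torch
import torch.nn as nn

n_layers, d_model = 12, 768

# Output projection of a residual branch (e.g. an attention or MLP output layer).
proj = nn.Linear(d_model, d_model)

# Standard initialization, then scale down by 1/sqrt(2 * n_layers): each of the
# n_layers blocks contributes two residual branches, so this extra factor keeps
# the variance accumulated along the residual stream roughly constant with depth.
nn.init.normal_(proj.weight, mean=0.0, std=0.02)
with torch.no_grad():
    proj.weight.mul_(1.0 / math.sqrt(2 * n_layers))
```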
The muP (maximal update parameterization) framework takes this further, providing a principled way to initialize and scale learning rates so that hyperparameters transfer from small models to large ones. This is crucial when training costs millions of dollars: you want to find good hyperparameters on a small model and have them work on the full-scale model.
For fine-tuning pretrained models, the pretrained weights serve as their own initialization, so what matters is how any newly added layers (such as a classification head) are initialized. Research on LoRA and other parameter-efficient fine-tuning methods shows that the initialization of adapter weights significantly affects convergence speed and final performance.
Weight initialization may not be glamorous, but it is one of those foundational details that separates practitioners who can reliably train deep networks from those who struggle with mysterious convergence failures. Understanding the principles behind proper initialization saves countless hours of debugging.
