trunc_normal_的理解
在vit源码中,可以看见作者使用以下代码对位置编码和类别token的正态分布参数进行截断,但只是用标准差为0.02的值进行截断。
# 定义类别token和位置编码
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
# 对参数正态分布截断
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
为何选用这么小的标准差的正态分布来初始化张量,而不是标准差为1的标准正态分布。这样做的目的是控制初始化值的范围,使其更加集中,以确保模型在训练初期的稳定性。
- 避免梯度爆炸和消失:
在深层神经网络中,如果初始化的权重太大,经过多层的前向传播和反向传播,梯度可能会逐渐变大,导致梯度爆炸。反之,如果初始化的权重太小,梯度可能会逐渐变小,导致梯度消失。
使用较小的方差(标准差)初始化权重可以帮助保持梯度在合理范围内,从而有助于稳定训练过程。 - 帮助网络更快地收敛:
较小的初始化权重可以确保在训练初期网络输出和梯度在合理的范围内,这样有助于梯度下降法更快地找到优化路径,从而加速收敛。 - 防止过拟合:
较小的权重初始化可以减少模型在训练初期的复杂度,防止模型过度拟合训练数据。这有助于提高模型的泛化能力。
理论讲解:假设有个多层感知机(MLP),每一层的输入是前一层的输出。假设第
l
l
l的输入为
x
l
x^l
xl,输出为
y
l
y^l
yl,权重矩阵为
W
l
W^l
Wl,那么
y
l
=
W
l
x
l
+
b
l
y^l=W^lx^l+b^l
yl=Wlxl+bl
如果权重初始化太大,经过多层线性变化后,输出值会迅速增大,导致激活函数(如ReLu, sigmoid)进入饱和区,产生饱和或梯度消失问题;如果权重初始化过小,输出值会接近零,导致信号在网络中逐层减弱;
通常将标准差设置为0.02,可以保证初始化的权重较小,稳定梯度回传。
参数初始化方法
包含Xavier初始化和He初始化
Xavier初始化:适合sigmoid和tanh激活函数,权重的方差为:
2
f
a
n
i
n
+
f
a
n
o
u
t
\frac{2}{fan_{in}+fan_{out}}
fanin+fanout2
He初始化:适合ReLu激活函数及其变体,权重的方差为:
2
f
a
n
i
n
\frac{2}{fan_{in}}
fanin2
代码可视化,使用He初始化权重:
代码如下:
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
import matplotlib.pyplot as plt
import numpy as np
# 定义一个简单的全连接层
class SimpleModel(nn.Module):
def __init__(self, input_dim, output_dim):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(input_dim, output_dim)
self.initialize_weights()
def initialize_weights(self):
fan_in = self.fc.weight.size(1)
std = np.sqrt(2.0 / fan_in)
nn.init.trunc_normal_(self.fc.weight, std=std)
def forward(self, x):
return self.fc(x)
# 定义输入和输出的维度
input_dim = 256
output_dim = 10
# 创建模型实例
model = SimpleModel(input_dim, output_dim)
# 可视化初始化的权重分布
weights = model.fc.weight.data.numpy().flatten()
plt.hist(weights, bins=100, density=True)
plt.title(f'Histogram of Initialized Weights with He Initialization (std={weights.std():.4f})')
plt.xlabel('Weight Value')
plt.ylabel('Frequency')
plt.show()
# 打印初始化权重的统计信息
print(f"Mean: {weights.mean():.4f}, Std: {weights.std():.4f}")