PyTorch提供了几种常见的参数初始化方式的实现
Xavier Initialization: 基本思想是维持输入和输出的方差一致,避免了所有的输出值都为0, 使用于任何激活函数
Xavier 均匀分布:torch.nn.init.xavier_uniform_(tensor, gain = 1),
服从均匀分布U(-a, a), # 分布参数a=gain * sqrt(6 / (fan_in + fan_out)),
gain的大小由激活函数的类型来决定。 # 其中fan_in是指第i层神经元的个数,fan_out是指第i + 1层神经元的个数
```python for m in net.modules():
if isinstance(m, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.init.xavier_uniform_(m.weight)avier
for m in net.modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.xavier_uniform_(m.weight, gain = torch.nn.init.calculate_gain('relu'))
```# Xavier 正态分布: torch.nn.init.xavier_normal_(tensor, gain = 1)
服从正态分布N(mean = 0, std),
# 其中 std = gain * sqrt(2 / (fan_in + fan_out)) Kaiming Initialization: 由于Xavier初始化在tanh中表现的很好, 在relu中表现很差, 因此Kaiming
Initialization针对Xavier在relu表现不佳被提出。基本思想仍然从“输入输出方差一致性”角度出发,在Relu网络中,假设每一层有一半的神经元被激活,另一半为0。一般在使用Relu的网络中推荐使用这种初始化方式。
# - kaiming均匀分布
# torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu') # 服从 U(-a, a), a =
sqrt(6 / (1 + b ^2) * fan_in), 其中b为激活函数的负半轴的斜率, relu是0
# model 可以是fan_in或者fan_out。fan_in 表示使正向传播时,方差一致; fan_out使反向传播时, 方差一致
# nonlinearity 可选为relu和leaky_relu, 默认是leaky_relu
# kaiming正态分布, N~ (0,std),其中std = sqrt(2/(1+b^2)*fan_in)**
```python
# torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
for m in net.modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.kaiming_normal_(m.weight, mode = 'fan_in')
Orthogonal Initialization 正交初始化:
主要是解决神经网络中出现的梯度消失和梯度爆炸等问题,是RNN中常用的初始化方法
```python
for m in modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.orthogonal(m.weight)
常数初始化:
for m in modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.constant(m.weight, 0.5)
torch.nn.init.constant(m.bias, 0)
单层网络初始化方式:
conv = torch.nn.Conv2d(3, 64, kernel_size = 3, stride = 1, padding = 1)
torch.nn.init.xavier_uniform_(conv.weight)
torch.nn.init.xavier_uniform_(conv.bias, 0)
2.多层网络初始化方式:
使用apply函数; (2) 在_init_函数中用self.modules()进行初始化
#使用apply函数的方式进行初始化
#在weight_init中通过判断模块的类型来进行不同的参数初始化定义类型
def weight_init(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
torch.nn.init.xavier_normal_(m.weight.data)
torch.nn.init.constant_(m.bias.data, 0.0)
elif classname.find('Linear') != -1:
torch.nn.init.xavier_norrmal_(m.weight)
torch.nn.init.constant_(m.bias, 0.0)
model= Model();
model.apply(weight_init)
-
# 这里apply函数会递归搜索网络中所有Module, 并将weight_init函数应用在这些Module上
## 警告: 这种初始化方式采用的是递归的形式, 但是在Python中对递归的层数是有限制的。 ## 因此, 当网络的层数很深的时候, 可能会报错