[pytorch]权重初始化方法

权重初始化的方法封装在torch.nn.init里。具体在使用的时候先初始化层之后直接调。

初始化方法

常数初始化

w = torch.empty(3, 5)
nn.init.constant_(w, 0.3)

均匀分布

torch.nn.init.uniform_(tensor, a=0, b=1)
# a是分布的下界,b是上届

正态分布

torch.nn.init.normal_(tensor, mean=0, std=1)

稀疏初始化

torch.nn.init.sparse_(tensor, sparsity, std=0.01)

Xaivier 均匀分布

假设使用的是sigmoid函数。当权重值(值指的是绝对值)过小,输入值每经过网络层,方差都会减少,每一层的加权和很小,在sigmoid函数0附件的区域相当于线性函数,失去了DNN的非线性性。
当权重的值过大,输入值经过每一层后方差会迅速上升,每层的输出值将会很大,此时每层的梯度将会趋近于0.
xavier初始化可以使得输入值x方差经过网络层后的输出值y方差不变。

用一个均匀分布生成值,填充输入的张量或变量。结果张量中的值采样自U(-a, a),其中

a= gain * sqrt( 2/(fan_in + fan_out))* sqrt(3).

torch.nn.init.xavier_uniform_(tensor, gain=1)

#这里有一个gain,增益的大小是依据激活函数类型来设定,e.g.
nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

Xaivier 正态分布

用一个正态分布生成值,填充输入的张量或变量。结果张量中的值采样自均值为0,标准差为gain * sqrt(2/(fan_in + fan_out))的正态分布。

torch.nn.init.xavier_normal_(tensor, gain=1)

Kaiming 均匀分布

Xavier在tanh中表现的很好,但在Relu激活函数中表现的很差,所何凯明提出了针对于relu的初始化方法。pytorch默认使用kaiming正态分布初始化卷积层参数。

用一个均匀分布生成值,填充输入的张量或变量。结果张量中的值采样自U(-bound, bound),其中:

bound = sqrt(6 / ((1+a**2)*fan_in) )

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

Kaiming 正态分布

用一个正态分布生成值,填充输入的张量或变量。结果张量中的值采样自N(0,std)
N(0,std)的正态分布。

std = = sqrt(6 / ((1+a**2)*fan_in) )

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

应用的实例

单层网络

在创建model后直接调用torch.nn.innit里的初始化函数

layer1 = torch.nn.Linear(10,20)
torch.nn.init.xavier_uniform_(layer1.weight)
torch.nn.init.constant_(layer1.bias, 0)

使用apply初始化

apply(fn):将fn函数递归地应用到网络模型的每个子模型中,主要用在参数的初始化。

使用apply()时,需要先定义一个参数初始化的函数。

def weight_init(m):
    classname = m.__class__.__name__ # 得到网络层的名字,如ConvTranspose2d
    if classname.find('Conv') != -1:  # 使用了find函数,如果不存在返回值为-1,所以让其不等于-1
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

之后,定义自己的网络,得到网络模型,使用apply()函数,就可以分别对conv层和bn层进行参数初始化。

model = net()
model.apply(weight_init)
已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 数字20 设计师:CSDN官方博客 返回首页