总结
- 单层网络初始化
直接调用torch.nn.innit里的初始化函数 - 多层网络初始化
2.1 使用apply和weight_init函数
2.2 在__init__函数使用self.modules()初始化
详细阐述
1. 单层网络
- 在创建model后直接调用torch.nn.innit里的初始化函数
layer1 = torch.nn.Linear(10,20)
torch.nn.init.xavier_uniform_(layer1.weight)
torch.nn.init.constant_(layer1.bias, 0)
- 也可以重写reset_parameters()方法,并不推荐
2. 多层网络
- 使用torch.nn.Module.apply.
apply(fn): 看一下apply在Module的实现。
将函数fn递归的运用在每个子模块上,这些子模块由self.children()返回.
常被用来初始化网络层参数。注意fn需要一个参数。
具体步骤是:
- 定义weight_init函数,并在weight_init中通过判断模块的类型来进行不同的参数初始化定义类型。
- model=Net(…) 创建网络结构
- model.apply(weight_init),将weight_init初始化方式应用到submodels上
在以下的代码中只初始化线性层,至于卷积层,批归一化层见后面例子
示例:官方示例1
# -*- coding: utf-8 -*-
import torch
from torch import nn
# hyper parameters
in_dim=1
n_hidden_1=1
n_hidden_2=1
out_dim=1
class Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(in_dim, n_hidden_1),
nn.ReLU(True),
nn.Linear(n_hidden_1, n_hidden_2),
nn.ReLU(True),
nn.Linear(n_hidden_2, out_dim)
)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
# 1. 根据网络层的不同定义不同的初始化方式
def weight_init(m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
# 也可以判断是否为conv2d,使用相应的初始化方式
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# 是否为批归一化层
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# 2. 初始化网络结构
model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
# 3. 将weight_init应用在子模块上
model.apply(weight_init)
注意:此种初始化方式采用的递归,而在python中,对递归层数是有限制的,所以当网络结构很深时,可能会递归层数过深的错误.
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight.data, 0, 0.01)
# m.weight.data.normal_(0,0.01)
m.bias.data.zero_()
先从self.modules()中遍历每一层,判断各层属于什么类型,是否是Conv2d,是否是BatchNorm2d,是否是Linear的,然后根据不同类型的层,设定不同的权值初始化方法,例如Xavier,kaiming,normal_等等。
- 在__init__中迭代循环self.modules()来初始化网络参数
此种方法的官方实例:官方示例2
初始化后并使用apply函数将参数信息打印出来
import torch
from torch import nn
# hyper parameters
in_dim=1
n_hidden_1=1
n_hidden_2=1
out_dim=1
class Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(in_dim, n_hidden_1),
nn.ReLU(True),
nn.Linear(n_hidden_1, n_hidden_2),
nn.ReLU(True),
nn.Linear(n_hidden_2, out_dim)
)
# 迭代循环初始化参数
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, -100)
# 也可以判断是否为conv2d,使用相应的初始化方式
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight.item(), 1)
nn.init.constant_(m.bias.item(), 0)
def forward(self, x):
x = self.layer(x)
return x
model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
# 打印参数信息
def print_weight(m):
if isinstance(m, nn.Linear):
print("weight", m.weight.item())
print("bias:", m.bias.item())
print("next...")
model.apply(print_weight)
输出:
weight 1.0
bias: -100.0
next…weight 1.0
bias: -100.0
next…weight 1.0
bias: -100.0
next…
- 上面是根据对象类型来初始化,根据名字查找也可以:
# 权重初始化:根据不同类型的层,设定不同的权值初始化方法
def weights_init_normal(m): # m就是实例化的网络模型net对象,这样调用初始化权重 net.apply(weights_init_normal)
classname = m.__class__.__name__
if classname.find('Conv') != -1:
torch.nn.init.normal(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
torch.nn.init.normal(m.weight.data, 1.0, 0.02)
torch.nn.init.constant(m.bias.data, 0.0)
torch.nn.init中的初始化方法有以下几种:
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)
自己计算std
如果不想使用官方提供的Xavier和kaiming初始化呢?
自己计算std,然后调用torch.nn.init.normal_(tensor, maen, std)就OK了。
至于官方的kaiming和Xavier是如何计算fan_in(输入神经元个数)和fan_out(输出神经元个数),参考这篇博客
还需要注意一点,在pytorch 0.4中不存在Variable了,所以weight和bias就直接是tensor类型了。
参考链接: https://blog.csdn.net/dss_dssssd/article/details/83990511