本文介绍使用 PyTorch 对模型权重进行初始化的两种方式。
在学习过程中借鉴了如下文章的代码：
链接：《pytorch 实现初始化操作详细讲解 常用方案》
第一种方式：在类的初始化函数（__init__）中完成权重初始化。
第二种方式：在类外定义一个用于初始化模型参数的函数，实例化模型后通过 net.apply(...) 调用。
第一种方式
在类的初始化函数中使用
import torch
import torch.nn as nn
class Net(nn.Module):
    """Small CNN that initializes its own weights inside ``__init__``.

    Layers: two 3x3 convolutions (3 -> 64 -> 128 channels, stride 1,
    padding 1), a 2x2 average pool with stride 2, and a linear layer that
    maps ``128 * 32 * 32`` flattened features to 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(128 * 32 * 32, 10)
        # Parameter initialization is done here, once every layer exists.
        self._init_weights()

    def _init_weights(self):
        """Re-initialize every Conv2d / BatchNorm2d submodule in place.

        Conv2d weights get Kaiming-uniform init (fan_in, ReLU gain).
        BatchNorm2d weights are drawn from N(1.0, 0.02) and biases are set
        to 0. The model as defined above has no BatchNorm2d layer, so that
        branch only takes effect if one is added.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                print('初始化卷积层参数')
                nn.init.kaiming_uniform_(
                    layer.weight.data, mode='fan_in', nonlinearity='relu'
                )
                continue
            if isinstance(layer, nn.BatchNorm2d):
                print('初始化BN层参数')
                torch.nn.init.normal_(layer.weight.data, 1.0, 0.02)
                torch.nn.init.constant_(layer.bias.data, 0.0)

    def forward(self, x):
        """Run conv1 -> conv2 -> pool, flatten per sample, then apply fc."""
        features = x
        for stage in (self.conv1, self.conv2, self.pool):
            features = stage(features)
        # Keep the batch dimension and collapse everything else.
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
if __name__ == '__main__':
    # Smoke test: push a dummy batch through the model to check it runs.
    # Named `dummy_batch` rather than `input` so the builtin `input()` is
    # not shadowed at module level.
    dummy_batch = torch.ones(8, 3, 64, 64)
    net = Net()
    out = net(dummy_batch)
    print("out.shape:{}".format(out.shape))
第二种方式
在类外定义一个用于初始化模型参数的函数，实例化模型后通过 net.apply(...) 将该函数递归地应用到模型的每个子模块上
import torch
import torch.nn as nn
class Net(nn.Module):
    """Small CNN that keeps PyTorch's default parameter initialization.

    Layers: two 3x3 convolutions (3 -> 64 -> 128 channels, stride 1,
    padding 1), a 2x2 average pool with stride 2, and a linear layer that
    maps ``128 * 32 * 32`` flattened features to 10 outputs. No custom
    initialization is performed inside the class.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1
        )
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(in_features=128 * 32 * 32, out_features=10)

    def forward(self, x):
        """Apply conv1, conv2 and pool, flatten per sample, then apply fc."""
        pooled = self.pool(self.conv2(self.conv1(x)))
        batch_size = pooled.size(0)
        # Collapse channel and spatial dimensions into one feature vector.
        return self.fc(pooled.view(batch_size, -1))
# Weight-initialization function, intended to be passed to ``net.apply``.
_CONV_LAYER_TYPES = (
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
)


def weights_init_normal_kaiming(m):
    """Initialize one module in place; other module types are left alone.

    Convolution layers (``nn.Conv*`` / ``nn.ConvTranspose*`` and their
    subclasses) get Kaiming-uniform weights (fan_in, ReLU gain).
    ``nn.BatchNorm2d`` layers get weights drawn from N(1.0, 0.02) and a
    zero bias.

    Modules are matched with ``isinstance`` instead of searching the class
    name for a substring: a container whose name merely contains "Conv"
    (for example a ``ConvBlock(nn.Sequential)``) has no ``weight`` attribute
    and would otherwise raise ``AttributeError`` during ``net.apply``.

    Args:
        m: The module handed in by ``nn.Module.apply``.
    """
    if isinstance(m, _CONV_LAYER_TYPES):
        print('初始化卷积层参数')
        # nn.init functions already run under no_grad, so `.data` is not needed.
        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        print('初始化BN层参数')
        # With affine=False the layer has weight/bias set to None.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
if __name__ == '__main__':
    # Named `dummy_batch` rather than `input` so the builtin `input()` is
    # not shadowed at module level.
    dummy_batch = torch.ones(8, 3, 64, 64)
    net = Net()
    # After instantiating the model, initialize its weights: `apply` calls
    # the function on every submodule and on the model itself.
    net.apply(weights_init_normal_kaiming)
    out = net(dummy_batch)
    print("out.shape:{}".format(out.shape))
本文用于记录学习过程，若有错误，望批评指正。