在深度学习中,获取网络的参数信息,对网络的finetune非常重要,下面将举一个简单的例子说明如何获取网络参数。
# coding: utf-8
import torch.nn as nn
import torch.nn.functional as F
# ----------------------------------- state_dict
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 1, 3)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(1 * 3 * 3, 2)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 1 * 3 * 3)
x = F.relu(self.fc1(x))
return x
net = Net()
# 获取网络当前参数
net_state_dict = net.state_dict()
print('net_state_dict类型:', type(net_state_dict))
print('net_state_dict管理的参数: ', net_state_dict.keys())
for key, value in net_state_dict.items():
print('参数名: ', key, '\t大小: ', value.shape)
输出结果:
debug模型下,查看参数情况:
从上图结果可以看出,网络存储的其实就是一个字典数据。
1、 保存和加载网络参数:
# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
# ----------------------------------- load_state_dict
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 1, 3)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(1 * 3 * 3, 2)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 1 * 3 * 3)
x = F.relu(self.fc1(x))
return x
def zero_param(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.constant_(m.weight.data, 0)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
torch.nn.init.constant_(m.weight.data, 0)
m.bias.data.zero_()
net = Net()
# 保存,并加载模型参数(仅保存模型参数)
torch.save(net.state_dict(), 'net_params.pkl') # 假设训练好了一个模型net
pretrained_dict = torch.load('net_params.pkl')
# 将net的参数全部置0,方便对比
net.zero_param()
net_state_dict = net.state_dict()
print('conv1层的权值为:\n', net_state_dict['conv1.weight'], '\n')
# 网络net通过load_state_dict 加载参数
net.load_state_dict(pretrained_dict)
print('加载之后,conv1层的权值变为:\n', net_state_dict['conv1.weight'])
输出结果如下:
2、设置多个参数组、添加参数组
# coding: utf-8
import torch
import torch.optim as optim
w1 = torch.randn(2, 2)
w1.requires_grad = True
w2 = torch.randn(2, 2)
w2.requires_grad = True
w3 = torch.randn(2, 2)
w3.requires_grad = True
# 一个参数组
optimizer_1 = optim.SGD([w1, w3], lr=0.1)
print('len(optimizer.param_groups): ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')
# 两个参数组
optimizer_2 = optim.SGD([{'params': w1, 'lr': 0.1},
{'params': w2, 'lr': 0.001}])
print('len(optimizer.param_groups): ', len(optimizer_2.param_groups))
print(optimizer_2.param_groups)
输出结果:
# coding: utf-8
import torch
import torch.optim as optim
# ----------------------------------- add_param_group
w1 = torch.randn(2, 2)
w1.requires_grad = True
w2 = torch.randn(2, 2)
w2.requires_grad = True
w3 = torch.randn(2, 2)
w3.requires_grad = True
# 一个参数组
optimizer_1 = optim.SGD([w1, w2], lr=0.1)
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')
# 增加一个参数组
print('增加一组参数 w3\n')
optimizer_1.add_param_group({'params': w3, 'lr': 0.001, 'momentum': 0.8})
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')
print('可以看到,参数组是一个list,一个元素是一个dict,每个dict中都有lr, momentum等参数,这些都是可单独管理,单独设定,十分灵活!')
输出结果: