理论
load_state_dict(state_dict)
功能:将 state_dict 中的参数加载到当前网络,常用于 finetune。
代码
# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
# ----------------------------------- load_state_dict
class Net(nn.Module):
    """Minimal CNN used to demonstrate ``state_dict`` saving and loading.

    Pipeline: Conv2d(3->1, 3x3) -> ReLU -> MaxPool2d(2x2) -> flatten -> Linear(9->2) -> ReLU.
    NOTE(review): the flatten hard-codes 1*3*3 features, so the forward pass
    assumes an 8x8 spatial input (8 -> conv 3x3 -> 6 -> pool 2x2 -> 3).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(1 * 3 * 3, 2)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, then flatten to (N, 9) for the FC layer.
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 1 * 3 * 3)
        x = F.relu(self.fc1(x))
        return x

    def zero_param(self):
        """Zero every Conv2d/Linear weight and bias.

        Used in the demo so the effect of ``load_state_dict`` is visible:
        after zeroing, a reload restores the original (non-zero) values.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.constant_(m.weight.data, 0)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.constant_(m.weight.data, 0)
                m.bias.data.zero_()
net = Net()
# Save the model's parameters only (not the full module object).
torch.save(net.state_dict(), 'net_params.pkl')  # pretend `net` is an already-trained model
pretrained_dict = torch.load('net_params.pkl')
# Zero out all of net's parameters so the reload below is easy to see.
net.zero_param()
net_state_dict = net.state_dict()
print('conv1层的权值为:\n', net_state_dict['conv1.weight'], '\n')
# Restore the saved parameters with load_state_dict; net_state_dict holds
# references to the live tensors, so the second print shows the reloaded values.
net.load_state_dict(pretrained_dict)
print('加载之后,conv1层的权值变为:\n', net_state_dict['conv1.weight'])
结果
conv1层的权值为:
tensor([[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]])
加载之后,conv1层的权值变为:
tensor([[[[ 0.0316, 0.1448, -0.1615],
[-0.0211, 0.1815, 0.0439],
[-0.1738, -0.0357, 0.1491]],
[[ 0.0351, -0.1872, -0.1833],
[-0.0122, -0.0280, -0.1268],
[ 0.1247, 0.0967, -0.0031]],
[[ 0.1457, 0.0334, 0.0388],
[-0.1533, 0.0612, -0.1687],
[ 0.0567, -0.0991, 0.0273]]]])
该博客介绍了如何在PyTorch中使用`load_state_dict`函数加载模型参数,用于模型微调。首先定义了一个简单的卷积神经网络Net,然后将网络参数全部置零,通过`state_dict`保存和加载模型参数,展示了加载前后的权重变化。
2218

被折叠的评论
为什么被折叠?



