pytorch快速搭建卷积神经网络【优化器_第4课_加载参数到网络_load_state_dict】

该博客介绍了如何在PyTorch中使用`load_state_dict`函数加载模型参数,用于模型微调。首先定义了一个简单的卷积神经网络Net,然后将网络参数全部置零,通过`state_dict`保存和加载模型参数,展示了加载前后的权重变化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

理论

load_state_dict(state_dict)
功能:将 state_dict 中的参数加载到当前网络,常用于 finetune。

代码

# coding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------------- load_state_dict

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(1 * 3 * 3, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 1 * 3 * 3)
        x = F.relu(self.fc1(x))
        return x

    def zero_param(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.constant_(m.weight.data, 0)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.constant_(m.weight.data, 0)
                m.bias.data.zero_()
net = Net()

# 保存,并加载模型参数(仅保存模型参数)
torch.save(net.state_dict(), 'net_params.pkl')   # 假设训练好了一个模型net
pretrained_dict = torch.load('net_params.pkl')

# 将net的参数全部置0,方便对比
net.zero_param()
net_state_dict = net.state_dict()
print('conv1层的权值为:\n', net_state_dict['conv1.weight'], '\n')

# 通过load_state_dict 加载参数
net.load_state_dict(pretrained_dict)
print('加载之后,conv1层的权值变为:\n', net_state_dict['conv1.weight'])

结果

conv1层的权值为:
 tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]]) 

加载之后,conv1层的权值变为:
 tensor([[[[ 0.0316,  0.1448, -0.1615],
           [-0.0211,  0.1815,  0.0439],
           [-0.1738, -0.0357,  0.1491]],

          [[ 0.0351, -0.1872, -0.1833],
           [-0.0122, -0.0280, -0.1268],
           [ 0.1247,  0.0967, -0.0031]],

          [[ 0.1457,  0.0334,  0.0388],
           [-0.1533,  0.0612, -0.1687],
           [ 0.0567, -0.0991,  0.0273]]]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

【网络星空】

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值