pytorch中关于网络参数的相关操作

在深度学习中,获取网络的参数信息,对网络的finetune非常重要,下面将举一个简单的例子说明如何获取网络参数。

# coding: utf-8

import torch.nn as nn
import torch.nn.functional as F


# ----------------------------------- state_dict
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(1 * 3 * 3, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 1 * 3 * 3)
        x = F.relu(self.fc1(x))
        return x


net = Net()

# 获取网络当前参数
net_state_dict = net.state_dict()

print('net_state_dict类型:', type(net_state_dict))
print('net_state_dict管理的参数: ', net_state_dict.keys())
for key, value in net_state_dict.items():
    print('参数名: ', key, '\t大小: ',  value.shape)

输出结果:

 debug模型下,查看参数情况:

 从上图结果可以看出,网络存储的其实就是一个字典数据。

1、 保存和加载网络参数:


# coding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------------- load_state_dict

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(1 * 3 * 3, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 1 * 3 * 3)
        x = F.relu(self.fc1(x))
        return x

    def zero_param(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.constant_(m.weight.data, 0)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.constant_(m.weight.data, 0)
                m.bias.data.zero_()
net = Net()

# 保存,并加载模型参数(仅保存模型参数)
torch.save(net.state_dict(), 'net_params.pkl')   # 假设训练好了一个模型net
pretrained_dict = torch.load('net_params.pkl')

# 将net的参数全部置0,方便对比
net.zero_param()
net_state_dict = net.state_dict()
print('conv1层的权值为:\n', net_state_dict['conv1.weight'], '\n')

# 网络net通过load_state_dict 加载参数
net.load_state_dict(pretrained_dict)
print('加载之后,conv1层的权值变为:\n', net_state_dict['conv1.weight'])

输出结果如下:

 2、设置多个参数组、添加参数组


# coding: utf-8

import torch
import torch.optim as optim


w1 = torch.randn(2, 2)
w1.requires_grad = True

w2 = torch.randn(2, 2)
w2.requires_grad = True

w3 = torch.randn(2, 2)
w3.requires_grad = True

# 一个参数组
optimizer_1 = optim.SGD([w1, w3], lr=0.1)
print('len(optimizer.param_groups): ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

# 两个参数组
optimizer_2 = optim.SGD([{'params': w1, 'lr': 0.1},
                         {'params': w2, 'lr': 0.001}])
print('len(optimizer.param_groups): ', len(optimizer_2.param_groups))
print(optimizer_2.param_groups)

输出结果:

# coding: utf-8

import torch
import torch.optim as optim

# ----------------------------------- add_param_group

w1 = torch.randn(2, 2)
w1.requires_grad = True

w2 = torch.randn(2, 2)
w2.requires_grad = True

w3 = torch.randn(2, 2)
w3.requires_grad = True

# 一个参数组
optimizer_1 = optim.SGD([w1, w2], lr=0.1)
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

# 增加一个参数组
print('增加一组参数 w3\n')
optimizer_1.add_param_group({'params': w3, 'lr': 0.001, 'momentum': 0.8})

print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

print('可以看到,参数组是一个list,一个元素是一个dict,每个dict中都有lr, momentum等参数,这些都是可单独管理,单独设定,十分灵活!')

输出结果:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值