pytorch教程之nn.Module类详解——state_dict和parameters两个方法的差异性比较

前言:pytorch的模块Module类有很多的方法,前面的文章中已经介绍了四个常用的方法,这四个方法可以用于获取模块中所定义的对象(即每一个层)他们分别是children()、named_children()、modules()、named_modules()方法,本文介绍另外两个重要的方法,这两个方法会获取到模型中训练的参数(权值矩阵、偏置bias),这两个方法是model.state_dict()方法和model.parameters()方法。前面的文章参考:pytorch教程之nn.Module类详解——使用Module类来自定义模型

一、本文的模型案例

为了简单的演示,本文的模型较为简单,代码如下:

import torch
import torch.nn.functional as F
from torch.optim import SGD

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # 第一句话,调用父类的构造函数
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1=torch.nn.ReLU()
        self.max_pooling1=torch.nn.MaxPool2d(2,1)

        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu2=torch.nn.ReLU()
        self.max_pooling2=torch.nn.MaxPool2d(2,1)

        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pooling1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pooling2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x

model = MyNet() # 构造模型

二、model.state_dict()方法

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

注意:

(1)只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等,像什么池化层、BN层这些本身没有参数的层是没有在这个字典中的;

(2)这个方法的作用一方面是方便查看某一个层的权值和偏置数据,另一方面更多的是在模型保存的时候使用。

2.1 Module的层的权值以及bias查看

print(type(model.state_dict()))  # 查看state_dict所返回的类型,是一个“顺序字典OrderedDict”

for param_tensor in model.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
    print(param_tensor,'\t',model.state_dict()[param_tensor].size())

'''
conv1.weight     torch.Size([32, 3, 3, 3])
conv1.bias       torch.Size([32])
conv2.weight     torch.Size([32, 3, 3, 3])
conv2.bias       torch.Size([32])
dense1.weight    torch.Size([128, 288])
dense1.bias      torch.Size([128])
dense2.weight    torch.Size([10, 128])
dense2.bias      torch.Size([10])
'''

当然这里之查看每一个参数的维度信息,如果是查看具体的值也是一样的,跟字典的操作是一样的。

2.2 优化器optimizer的state_dict()方法

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

optimizer = SGD(model.parameters(),lr=0.001,momentum=0.9)
for var_name in optimizer.state_dict():
    print(var_name,'\t',optimizer.state_dict()[var_name])

'''
state    {}
param_groups     [{'lr': 0.001, 
                   'momentum': 0.9,
                   'dampening': 0, 
                   'weight_decay': 0, 
                   'nesterov': False, 
                   'params': [1412966600640, 1412966613064, 1412966613136, 1412966613208, 
                             1412966613280, 1412966613352, 1412966613496, 1412966613568]
                 }]
'''

三、model.parameters()方法

这个方法也会获得模型的参数信息,如下:

print(type(model.parameters()))  # 返回的是一个generator

for para in model.parameters():
    print(para.size())  # 只查看形状
'''
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([128, 288])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
'''

从这里可以看出,其实这个state_dict方法所得到结果差不多,不同的是,model.parameters()方法返回的是一个生成器generator,每一个元素是从开头到结尾的参数,parameters没有对应的key名称,是一个由纯参数组成的generator,而state_dict是一个字典,包含了一个key

其实Module还有一个与parameters类似的函数,named_parameters,而且parameters正是通过named_parameters来实现的,

看一下parameters的定义,很简单:

def parameters(self, recurse=True):
    for name, param in self.named_parameters(recurse=recurse):
        yield param

来一起看一下named_parameters的简单使用。

print(type(model.named_parameters()))  # 返回的是一个generator

for para in model.named_parameters(): # 返回的每一个元素是一个元组 tuple 
    '''
    是一个元组 tuple ,元组的第一个元素是参数所对应的名称,第二个元素就是对应的参数值
    '''
    print(para[0],'\t',para[1].size())

'''
conv1.weight     torch.Size([32, 3, 3, 3])
conv1.bias       torch.Size([32])
conv2.weight     torch.Size([32, 3, 3, 3])
conv2.bias       torch.Size([32])
dense1.weight    torch.Size([128, 288])
dense1.bias      torch.Size([128])
dense2.weight    torch.Size([10, 128])
dense2.bias      torch.Size([10])
'''

总结:model.state_dict()、model.parameters()、model.named_parameters()这三个方法都可以查看Module的参数信息,用于更新参数,或者用于模型的保存。

 

 

  • 31
    点赞
  • 121
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值