Some Notes on PyTorch's nn.Module

1. model.state_dict()

In PyTorch, state_dict is a simple Python dictionary object (a collections.OrderedDict, i.e. an ordered dict; see https://www.cnblogs.com/single-boy/p/7446293.html) that maps each layer to its corresponding parameters (e.g. each layer's weights, biases, and so on).

Notes:

(1) state_dict holds everything that carries parameters, whether trainable or not: convolution layers, linear layers, BN layers, and so on. Layers that have no parameters of their own, such as pooling and activation layers, do not appear in this dict;

(2) this method is useful both for inspecting the weights and biases of a particular layer and, more importantly, when saving a model (a save/load sketch follows the listing below).
Code example:

import torchvision.models as models
model = models.resnet18()
for k,v in model.state_dict().items():
    print(k,v.shape)
# Output:
conv1.weight torch.Size([64, 3, 7, 7])
bn1.weight torch.Size([64])
bn1.bias torch.Size([64])
bn1.running_mean torch.Size([64])
bn1.running_var torch.Size([64])
bn1.num_batches_tracked torch.Size([])
layer1.0.conv1.weight torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight torch.Size([64])
layer1.0.bn1.bias torch.Size([64])
layer1.0.bn1.running_mean torch.Size([64])
layer1.0.bn1.running_var torch.Size([64])
layer1.0.bn1.num_batches_tracked torch.Size([])
layer1.0.conv2.weight torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight torch.Size([64])
layer1.0.bn2.bias torch.Size([64])
layer1.0.bn2.running_mean torch.Size([64])
layer1.0.bn2.running_var torch.Size([64])
layer1.0.bn2.num_batches_tracked torch.Size([])
layer1.1.conv1.weight torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight torch.Size([64])
layer1.1.bn1.bias torch.Size([64])
layer1.1.bn1.running_mean torch.Size([64])
layer1.1.bn1.running_var torch.Size([64])
layer1.1.bn1.num_batches_tracked torch.Size([])
layer1.1.conv2.weight torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight torch.Size([64])
layer1.1.bn2.bias torch.Size([64])
layer1.1.bn2.running_mean torch.Size([64])
layer1.1.bn2.running_var torch.Size([64])
layer1.1.bn2.num_batches_tracked torch.Size([])
layer2.0.conv1.weight torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight torch.Size([128])
layer2.0.bn1.bias torch.Size([128])
layer2.0.bn1.running_mean torch.Size([128])
layer2.0.bn1.running_var torch.Size([128])
layer2.0.bn1.num_batches_tracked torch.Size([])
layer2.0.conv2.weight torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight torch.Size([128])
layer2.0.bn2.bias torch.Size([128])
layer2.0.bn2.running_mean torch.Size([128])
layer2.0.bn2.running_var torch.Size([128])
layer2.0.bn2.num_batches_tracked torch.Size([])
layer2.0.downsample.0.weight torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight torch.Size([128])
layer2.0.downsample.1.bias torch.Size([128])
layer2.0.downsample.1.running_mean torch.Size([128])
layer2.0.downsample.1.running_var torch.Size([128])
layer2.0.downsample.1.num_batches_tracked torch.Size([])
layer2.1.conv1.weight torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight torch.Size([128])
layer2.1.bn1.bias torch.Size([128])
layer2.1.bn1.running_mean torch.Size([128])
layer2.1.bn1.running_var torch.Size([128])
layer2.1.bn1.num_batches_tracked torch.Size([])
layer2.1.conv2.weight torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight torch.Size([128])
layer2.1.bn2.bias torch.Size([128])
layer2.1.bn2.running_mean torch.Size([128])
layer2.1.bn2.running_var torch.Size([128])
layer2.1.bn2.num_batches_tracked torch.Size([])
layer3.0.conv1.weight torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight torch.Size([256])
layer3.0.bn1.bias torch.Size([256])
layer3.0.bn1.running_mean torch.Size([256])
layer3.0.bn1.running_var torch.Size([256])
layer3.0.bn1.num_batches_tracked torch.Size([])
layer3.0.conv2.weight torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight torch.Size([256])
layer3.0.bn2.bias torch.Size([256])
layer3.0.bn2.running_mean torch.Size([256])
layer3.0.bn2.running_var torch.Size([256])
layer3.0.bn2.num_batches_tracked torch.Size([])
layer3.0.downsample.0.weight torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight torch.Size([256])
layer3.0.downsample.1.bias torch.Size([256])
layer3.0.downsample.1.running_mean torch.Size([256])
layer3.0.downsample.1.running_var torch.Size([256])
layer3.0.downsample.1.num_batches_tracked torch.Size([])
layer3.1.conv1.weight torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight torch.Size([256])
layer3.1.bn1.bias torch.Size([256])
layer3.1.bn1.running_mean torch.Size([256])
layer3.1.bn1.running_var torch.Size([256])
layer3.1.bn1.num_batches_tracked torch.Size([])
layer3.1.conv2.weight torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight torch.Size([256])
layer3.1.bn2.bias torch.Size([256])
layer3.1.bn2.running_mean torch.Size([256])
layer3.1.bn2.running_var torch.Size([256])
layer3.1.bn2.num_batches_tracked torch.Size([])
layer4.0.conv1.weight torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight torch.Size([512])
layer4.0.bn1.bias torch.Size([512])
layer4.0.bn1.running_mean torch.Size([512])
layer4.0.bn1.running_var torch.Size([512])
layer4.0.bn1.num_batches_tracked torch.Size([])
layer4.0.conv2.weight torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight torch.Size([512])
layer4.0.bn2.bias torch.Size([512])
layer4.0.bn2.running_mean torch.Size([512])
layer4.0.bn2.running_var torch.Size([512])
layer4.0.bn2.num_batches_tracked torch.Size([])
layer4.0.downsample.0.weight torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight torch.Size([512])
layer4.0.downsample.1.bias torch.Size([512])
layer4.0.downsample.1.running_mean torch.Size([512])
layer4.0.downsample.1.running_var torch.Size([512])
layer4.0.downsample.1.num_batches_tracked torch.Size([])
layer4.1.conv1.weight torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight torch.Size([512])
layer4.1.bn1.bias torch.Size([512])
layer4.1.bn1.running_mean torch.Size([512])
layer4.1.bn1.running_var torch.Size([512])
layer4.1.bn1.num_batches_tracked torch.Size([])
layer4.1.conv2.weight torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight torch.Size([512])
layer4.1.bn2.bias torch.Size([512])
layer4.1.bn2.running_mean torch.Size([512])
layer4.1.bn2.running_var torch.Size([512])
layer4.1.bn2.num_batches_tracked torch.Size([])
fc.weight torch.Size([1000, 512])
fc.bias torch.Size([1000])

Note that here k is just the name, while v is a plain tensor, not an nn.Parameter; the values in state_dict carry no extra attributes such as requires_grad.
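
Because state_dict is an ordinary (ordered) dict of tensors, it is the standard vehicle for saving and restoring a model. A minimal save/load sketch (the file name resnet18_params.pth is just an illustrative choice):

import torch
import torchvision.models as models

model = models.resnet18()
# inspect a single layer's weights by key
print(model.state_dict()['conv1.weight'].shape)  # torch.Size([64, 3, 7, 7])

# save only the parameters/buffers, not the whole module object
torch.save(model.state_dict(), 'resnet18_params.pth')

# rebuild the same architecture, then load the saved weights back in
model2 = models.resnet18()
model2.load_state_dict(torch.load('resnet18_params.pth'))
model2.eval()  # switch to eval mode before inference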

1) The optimizer's state_dict()

The optimizer object (Optimizer) also has a state_dict, which contains the optimizer's internal state as well as the hyperparameters in use (lr, momentum, weight_decay, etc.). The 'params' entries inside param_groups correspond to the trainable parameters returned by model.named_parameters(), introduced below.
Code:

import torchvision.models as models
import torch
model = models.resnet18()
optimizer = torch.optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

for k,v in optimizer.state_dict().items():
    print(k,v)
# Output (the integers under 'params' identify the trainable parameter tensors; exact values vary by run and PyTorch version)
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139949035513088, 139949035513008, 139949035513248, 139949035513888, 139949035513808, 139949035514048, 139949035514848, 139949035514768, 139949035515008, 139949035515648, 139949035515568, 139949035515808, 139949035561760, 139949035562000, 139949035562080, 139949035564160, 139949035486416, 139949035564320, 139949035593968, 139949035594208, 139949035594288, 139949035562800, 139949035563040, 139949035563120, 139949035595168, 139949035595408, 139949035595488, 139949035596368, 139949035596608, 139949035596688, 139949035627536, 139949035627776, 139949035627856, 139949035628736, 139949035628976, 139949035629056, 139949035597408, 139949035597648, 139949035597728, 139949035629936, 139949035630176, 139949035630256, 139949035672192, 139949035672432, 139949035672512, 139949035674592, 139949035674832, 139949035674912, 139949035704560, 139949035704800, 139949035704880, 139949035673232, 139949035673472, 139949035673552, 139949035705760, 139949035706000, 139949035706080, 139949035706960, 139949035707200, 139949035707280, 139949035707840, 139949035707920]}]
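
In practice the optimizer's state_dict is saved together with the model's so that training can resume where it left off. A minimal checkpointing sketch (checkpoint.pth and the epoch value are illustrative):

import torch
import torchvision.models as models

model = models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
epoch = 10  # hypothetical current epoch

# save model and optimizer state together
torch.save({'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict()}, 'checkpoint.pth')

# resume: rebuild model/optimizer, then restore both state dicts
ckpt = torch.load('checkpoint.pth')
model.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state'])
start_epoch = ckpt['epoch'] + 1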

2. model.parameters() and model.named_parameters()

The result is close to what state_dict gives, but with a few differences. model.parameters() yields only the trainable parameters: for a BN layer that means weight and bias, while running_mean and running_var, which are updated during forward rather than learned, are not included. It returns a generator that yields the parameters in order from the first layer to the last, with no key names attached, whereas state_dict is a dictionary keyed by name. In addition, each element of parameters() is an nn.Parameter and carries a requires_grad attribute, which the plain tensors in state_dict do not. model.named_parameters() is the same, except that it yields (name, parameter) pairs.
Code:

import torchvision.models as models
import torch
model = models.resnet18()

for k,v in model.named_parameters():
    print(k,v.shape)
# Output:
conv1.weight torch.Size([64, 3, 7, 7])
bn1.weight torch.Size([64])
bn1.bias torch.Size([64])
layer1.0.conv1.weight torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight torch.Size([64])
layer1.0.bn1.bias torch.Size([64])
layer1.0.conv2.weight torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight torch.Size([64])
layer1.0.bn2.bias torch.Size([64])
layer1.1.conv1.weight torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight torch.Size([64])
layer1.1.bn1.bias torch.Size([64])
layer1.1.conv2.weight torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight torch.Size([64])
layer1.1.bn2.bias torch.Size([64])
layer2.0.conv1.weight torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight torch.Size([128])
layer2.0.bn1.bias torch.Size([128])
layer2.0.conv2.weight torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight torch.Size([128])
layer2.0.bn2.bias torch.Size([128])
layer2.0.downsample.0.weight torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight torch.Size([128])
layer2.0.downsample.1.bias torch.Size([128])
layer2.1.conv1.weight torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight torch.Size([128])
layer2.1.bn1.bias torch.Size([128])
layer2.1.conv2.weight torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight torch.Size([128])
layer2.1.bn2.bias torch.Size([128])
layer3.0.conv1.weight torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight torch.Size([256])
layer3.0.bn1.bias torch.Size([256])
layer3.0.conv2.weight torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight torch.Size([256])
layer3.0.bn2.bias torch.Size([256])
layer3.0.downsample.0.weight torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight torch.Size([256])
layer3.0.downsample.1.bias torch.Size([256])
layer3.1.conv1.weight torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight torch.Size([256])
layer3.1.bn1.bias torch.Size([256])
layer3.1.conv2.weight torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight torch.Size([256])
layer3.1.bn2.bias torch.Size([256])
layer4.0.conv1.weight torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight torch.Size([512])
layer4.0.bn1.bias torch.Size([512])
layer4.0.conv2.weight torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight torch.Size([512])
layer4.0.bn2.bias torch.Size([512])
layer4.0.downsample.0.weight torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight torch.Size([512])
layer4.0.downsample.1.bias torch.Size([512])
layer4.1.conv1.weight torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight torch.Size([512])
layer4.1.bn1.bias torch.Size([512])
layer4.1.conv2.weight torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight torch.Size([512])
layer4.1.bn2.bias torch.Size([512])
fc.weight torch.Size([1000, 512])
fc.bias torch.Size([1000])
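
A quick way to see the difference between the two listings is to compare their key sets: state_dict covers every parameter plus the BN buffers, while named_parameters() covers only the trainable part. A short sketch:

import torchvision.models as models

model = models.resnet18()
sd_keys = set(model.state_dict().keys())
param_keys = set(dict(model.named_parameters()).keys())

# the difference is exactly the non-trainable buffers: running_mean,
# running_var and num_batches_tracked of every BN layer
for k in sorted(sd_keys - param_keys):
    print(k)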

1) Freezing certain parameters

To keep certain layers from being trained, set their requires_grad to False:

import torchvision.models as models
import torch
model = models.resnet18()

for k, v in model.named_parameters():
    if k.startswith('conv1'):
        v.requires_grad = False

for k, v in model.named_parameters():
    print(k, v.requires_grad)
# Output:
conv1.weight False
bn1.weight True
bn1.bias True
layer1.0.conv1.weight True
layer1.0.bn1.weight True
layer1.0.bn1.bias True
layer1.0.conv2.weight True
layer1.0.bn2.weight True
layer1.0.bn2.bias True
layer1.1.conv1.weight True
layer1.1.bn1.weight True
layer1.1.bn1.bias True
layer1.1.conv2.weight True
layer1.1.bn2.weight True
layer1.1.bn2.bias True
layer2.0.conv1.weight True
layer2.0.bn1.weight True
layer2.0.bn1.bias True
layer2.0.conv2.weight True
layer2.0.bn2.weight True
layer2.0.bn2.bias True
layer2.0.downsample.0.weight True
layer2.0.downsample.1.weight True
layer2.0.downsample.1.bias True
layer2.1.conv1.weight True
layer2.1.bn1.weight True
layer2.1.bn1.bias True
layer2.1.conv2.weight True
layer2.1.bn2.weight True
layer2.1.bn2.bias True
layer3.0.conv1.weight True
layer3.0.bn1.weight True
layer3.0.bn1.bias True
layer3.0.conv2.weight True
layer3.0.bn2.weight True
layer3.0.bn2.bias True
layer3.0.downsample.0.weight True
layer3.0.downsample.1.weight True
layer3.0.downsample.1.bias True
layer3.1.conv1.weight True
layer3.1.bn1.weight True
layer3.1.bn1.bias True
layer3.1.conv2.weight True
layer3.1.bn2.weight True
layer3.1.bn2.bias True
layer4.0.conv1.weight True
layer4.0.bn1.weight True
layer4.0.bn1.bias True
layer4.0.conv2.weight True
layer4.0.bn2.weight True
layer4.0.bn2.bias True
layer4.0.downsample.0.weight True
layer4.0.downsample.1.weight True
layer4.0.downsample.1.bias True
layer4.1.conv1.weight True
layer4.1.bn1.weight True
layer4.1.bn1.bias True
layer4.1.conv2.weight True
layer4.1.bn2.weight True
layer4.1.bn2.bias True
fc.weight True
fc.bias True
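
With some parameters frozen, a common follow-up pattern is to hand the optimizer only the parameters that still require gradients. A minimal sketch:

import torch
import torchvision.models as models

model = models.resnet18()
for k, v in model.named_parameters():
    if k.startswith('conv1'):
        v.requires_grad = False

# only pass the still-trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.001, momentum=0.9)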

3. Some other methods in nn.Module

For example model.children(), model.named_children(), model.modules() and model.named_modules(); see https://blog.csdn.net/qq_27825451/article/details/90550890 for details.
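
In short, children()/named_children() iterate over the direct submodules only, while modules()/named_modules() recurse through all nested submodules (including the model itself). A quick illustration:

import torchvision.models as models

model = models.resnet18()

# direct submodules only: conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc
for name, m in model.named_children():
    print('child:', name, type(m).__name__)

# modules() also yields every block and layer nested inside layer1..layer4
print('total modules:', len(list(model.modules())))
print('direct children:', len(list(model.children())))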

4. Fully freezing BN layers

A BN layer has two kinds of state: the trainable weight and bias, and the running_mean and running_var buffers updated during forward. The freeze method above only fixes weight and bias; it cannot fix running_mean and running_var, so the model still drifts after training. To freeze BN layers completely, see https://zhuanlan.zhihu.com/p/259160576; the BN modules can then be located by iterating over the named modules and switched to eval mode, as in the reference code below:

import torch.nn as nn

for k, v in model.named_modules():  # with nn.parallel.DataParallel or nn.parallel.DistributedDataParallel, use model.module.named_modules()
    # print(k)
    if isinstance(v, nn.BatchNorm2d):  # matching by name ('.bn' in k) would miss 'bn1' and the downsample BN layers in resnet18
        v.eval()
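
Note that model.train() recursively puts every submodule back into training mode, BN included, so the loop above must be re-applied after each call to model.train(). A sketch wrapping it in a helper (freeze_bn is a hypothetical name):

import torch.nn as nn

def freeze_bn(model):
    # put every BN layer into eval mode so running_mean/running_var stop updating
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

model.train()     # resets all submodules, including BN, to train mode
freeze_bn(model)  # so re-apply eval to the BN layers right after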