PyTorch: Reading and Setting Model Parameters, Using FlowNetSimple as an Example

1. Background

When doing deep learning, unless you have money to burn, you rarely train a network directly on a huge dataset. In most cases the cheaper and more effective route is to take a pretrained model and continue training it for your own purpose.
This post therefore briefly records how to read the parameters of a pretrained model in PyTorch and load them into your own model, so as to cut down your own computation as much as possible.
To walk through the whole process concretely, this post sets up a small experiment: I designed a network whose first half is identical to the encoder of FlowNetSimple, and whose second half is a fully connected classification network.
The figure below shows the FlowNetSimple architecture; its refinement part is the decoder (similar to UNet).
[Figure: FlowNetSimple network architecture]
The structure designed for this post simply removes the decoder and replaces it with fully connected layers; the original code is omitted since it is easy to write, but a minimal sketch is given below.
[Figure: the modified network, with the decoder replaced by fully connected layers]
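For reference, here is a minimal sketch of such a model (my assumption of what it looks like, not the post's exact code). The encoder layer names (conv1 through conv6_1) deliberately mirror FlowNetS so that the pretrained parameters line up in Section 3.4, while the conv helper, the pooling step, and the fully connected sizes (fc_1, fc_2, num_classes) are placeholders.

import torch
import torch.nn as nn

def conv(batch_norm, in_ch, out_ch, kernel_size=3, stride=1):
    # Mirrors the conv helper in flownet2-pytorch: Conv2d (+ optional BatchNorm) + LeakyReLU
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size, stride,
                        padding=(kernel_size - 1) // 2, bias=not batch_norm)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_ch))
    layers.append(nn.LeakyReLU(0.1, inplace=True))
    return nn.Sequential(*layers)

class EncoderClassifier(nn.Module):
    # batchNorm=False matches the downloaded checkpoint, whose conv layers
    # have biases and no BatchNorm keys (see the key listing in Section 3.2.3)
    def __init__(self, input_channels=12, num_classes=10, batchNorm=False):
        super().__init__()
        # Encoder: same layer names and shapes as FlowNetS
        self.conv1   = conv(batchNorm, input_channels, 64, kernel_size=7, stride=2)
        self.conv2   = conv(batchNorm, 64, 128, kernel_size=5, stride=2)
        self.conv3   = conv(batchNorm, 128, 256, kernel_size=5, stride=2)
        self.conv3_1 = conv(batchNorm, 256, 256)
        self.conv4   = conv(batchNorm, 256, 512, stride=2)
        self.conv4_1 = conv(batchNorm, 512, 512)
        self.conv5   = conv(batchNorm, 512, 512, stride=2)
        self.conv5_1 = conv(batchNorm, 512, 512)
        self.conv6   = conv(batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(batchNorm, 1024, 1024)
        # Classifier head replacing the decoder; these sizes are arbitrary
        self.fc_1 = nn.Sequential(nn.Linear(1024, 256), nn.ReLU(inplace=True))
        self.fc_2 = nn.Sequential(nn.Linear(256, num_classes))

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        out = self.conv3_1(self.conv3(out))
        out = self.conv4_1(self.conv4(out))
        out = self.conv5_1(self.conv5(out))
        out = self.conv6_1(self.conv6(out))
        out = out.mean(dim=(2, 3))  # global average pooling down to (N, 1024)
        return self.fc_2(self.fc_1(out))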


2. References

https://github.com/NVIDIA/flownet2-pytorch
Dive into DL PyTorch


3. Procedure

3.1 Downloading the pretrained model

The pretrained model used here is the FlowNetS checkpoint from https://github.com/NVIDIA/flownet2-pytorch.
If you are not familiar with FlowNetSimple, its code is shown below; I have added a few brief comments.

'''
Portions of this code copyright 2017, Clement Pinard
'''

import torch
import torch.nn as nn
from torch.nn import init

import math
import numpy as np

from .submodules import *
'Parameter count : 38,676,504 '

class FlowNetS(nn.Module):
    def __init__(self, args, input_channels=12, batchNorm=True):
        super(FlowNetS, self).__init__()

        # Encoder. conv is a helper defined in this repo, equivalent to
        # Conv2d + LeakyReLU (with optional BatchNorm)
        self.batchNorm = batchNorm
        self.conv1   = conv(self.batchNorm,  input_channels,   64, kernel_size=7, stride=2)
        self.conv2   = conv(self.batchNorm,  64,  128, kernel_size=5, stride=2)
        self.conv3   = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2)
        self.conv3_1 = conv(self.batchNorm, 256,  256)
        self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
        self.conv4_1 = conv(self.batchNorm, 512,  512)
        self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
        self.conv5_1 = conv(self.batchNorm, 512,  512)
        self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(self.batchNorm, 1024, 1024)

        # Decoder; deconv is the up-convolution helper
        self.deconv5 = deconv(1024, 512)
        self.deconv4 = deconv(1026, 256)
        self.deconv3 = deconv(770, 128)
        self.deconv2 = deconv(386, 64)

        # No need to worry about these; they are used to predict optical flow
        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(1026)
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.bias is not None:
                    init.uniform_(m.bias)
                init.xavier_uniform_(m.weight)

            if isinstance(m, nn.ConvTranspose2d):
                if m.bias is not None:
                    init.uniform_(m.bias)
                init.xavier_uniform_(m.weight)
                # init_deconv_bilinear(m.weight)
        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, x):
        # Encoder
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3_1(self.conv3(out_conv2))
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))

        # As the FlowNet paper explains, every decoder level returns its own
        # flow prediction, each at a different resolution
        flow6       = self.predict_flow6(out_conv6)
        flow6_up    = self.upsampled_flow6_to_5(flow6)
        out_deconv5 = self.deconv5(out_conv6)

        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
        flow5       = self.predict_flow5(concat5)
        flow5_up    = self.upsampled_flow5_to_4(flow5)
        out_deconv4 = self.deconv4(concat5)

        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
        flow4       = self.predict_flow4(concat4)
        flow4_up    = self.upsampled_flow4_to_3(flow4)
        out_deconv3 = self.deconv3(concat4)

        concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
        flow3       = self.predict_flow3(concat3)
        flow3_up    = self.upsampled_flow3_to_2(flow3)
        out_deconv2 = self.deconv2(concat3)

        concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
        flow2 = self.predict_flow2(concat2)

        if self.training:
            return flow2, flow3, flow4, flow5, flow6
        else:
            return flow2,

3.2 Inspecting the pretrained model's parameters

After finally getting the pretrained model downloaded, we can start extracting its parameters. First, though, let's take a quick look at what this checkpoint actually contains. Note that the file downloaded from the link above is named FlowNet2-S_checkpoint.pth.tar.

3.2.1 Loading the checkpoint
# You do not have to instantiate a model before loading a pretrained checkpoint; its contents can be read directly
state_dict = torch.load('FlowNet2-S_checkpoint.pth.tar')
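
One caveat worth adding here: if the checkpoint was saved from a GPU and you are inspecting it on a CPU-only machine, torch.load can raise a device error. The standard map_location argument handles this:

# Map all tensors to the CPU, regardless of the device they were saved from
state_dict = torch.load('FlowNet2-S_checkpoint.pth.tar', map_location='cpu')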
3.2.2 Printing the keys of state_dict
# A side note: PyTorch saves checkpoints as pickled Python dictionaries, so at the core everything is key-value pairs
for k, v in state_dict.items():
    print(k)

Output

epoch
best_EPE
state_dict

My first reaction to this output was honestly confusion, but it quickly made sense: epoch and best_EPE store some information from the training run, which we are not interested in; state_dict is what we actually care about.
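
For completeness, the two bookkeeping entries can be printed directly (the exact values depend on the checkpoint):

# Training metadata stored alongside the weights
print(state_dict['epoch'], state_dict['best_EPE'])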

3.2.3 Printing the keys of state_dict['state_dict']
for k, v in state_dict['state_dict'].items():
    print(k)

Output

conv1.0.weight
conv1.0.bias
conv2.0.weight
conv2.0.bias
conv3.0.weight
conv3.0.bias
conv3_1.0.weight
conv3_1.0.bias
conv4.0.weight
conv4.0.bias
conv4_1.0.weight
conv4_1.0.bias
conv5.0.weight
conv5.0.bias
conv5_1.0.weight
conv5_1.0.bias
conv6.0.weight
conv6.0.bias
conv6_1.0.weight
conv6_1.0.bias
deconv5.0.weight
deconv5.0.bias
deconv4.0.weight
deconv4.0.bias
deconv3.0.weight
deconv3.0.bias
deconv2.0.weight
deconv2.0.bias
predict_flow6.weight
predict_flow6.bias
predict_flow5.weight
predict_flow5.bias
predict_flow4.weight
predict_flow4.bias
predict_flow3.weight
predict_flow3.bias
predict_flow2.weight
predict_flow2.bias
upsampled_flow6_to_5.weight
upsampled_flow5_to_4.weight
upsampled_flow4_to_3.weight
upsampled_flow3_to_2.weight

It is easy to see that these keys correspond one-to-one to the layers of the FlowNetSimple model, which lays the groundwork for reading them out later.

3.3 Inspecting our model's parameters

There is not much to explain here, but it is still worth iterating over the model we defined to see what parameters it contains.

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

Output

conv1.0.weight
conv1.0.bias
conv2.0.weight
conv2.0.bias
conv3.0.weight
conv3.0.bias
conv3_1.0.weight
conv3_1.0.bias
conv4.0.weight
conv4.0.bias
conv4_1.0.weight
conv4_1.0.bias
conv5.0.weight
conv5.0.bias
conv5_1.0.weight
conv5_1.0.bias
conv6.0.weight
conv6.0.bias
conv6_1.0.weight
conv6_1.0.bias
fc_1.0.weight
fc_1.0.bias
fc_2.0.weight
fc_2.0.bias

It is easy to see that most of the structure is identical to FlowNetSimple; the only difference is the two fully connected layers at the end. A quick sanity check follows below.
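
As that sanity check (my addition, assuming model and state_dict are defined as above), we can compute exactly which keys the two models share; only the encoder parameters should overlap, and the fc layers should be left over:

pretrained_keys = set(state_dict['state_dict'].keys())
model_keys = set(model.state_dict().keys())
print(sorted(pretrained_keys & model_keys))  # shared encoder parameters
print(sorted(model_keys - pretrained_keys))  # only in our model: the fc layers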

3.4 Assigning the model parameters

Normally, if the pretrained model and your own model are exactly the same, a plain model.load_state_dict(torch.load(PATH)) is enough. Here, however, the pretrained model and our model differ, which means we only need a subset of the parameters and can ignore the rest. I took a very lazy approach and assigned the corresponding modules one by one (a tidier alternative is sketched at the end of this section).

3.4.1 Accessing the pretrained parameters

In the state_dict loaded above, the model's parameters are likewise stored as key-value pairs and can be read as follows:

# Replace the key with whichever layer you need
state_dict['state_dict']['conv2.0.bias']
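
Before copying anything, it is worth printing the tensor's shape to confirm it matches the target layer. For example, conv2 outputs 128 channels, so its bias should have 128 entries:

print(state_dict['state_dict']['conv2.0.bias'].shape)  # torch.Size([128])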
3.4.2 Accessing our own model's parameters

In this code, every layer of the defined model is an attribute of the model object, so its parameters can be accessed like this:

# The [0] index is needed because I wrapped each layer in a Sequential; the idea is the same either way
model.conv2[0].bias.data
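
As an aside, if you prefer not to rebind .data (which bypasses autograd's bookkeeping), the same parameter can be overwritten in place with copy_. This variant is not what the post uses, but it is equivalent here:

# In-place copy under no_grad, instead of rebinding .data
with torch.no_grad():
    model.conv2[0].bias.copy_(state_dict['state_dict']['conv2.0.bias'])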
3.4.3 Assignment

Forgive me for being so brute force!

model.conv1[0].weight.data = state_dict['state_dict']['conv1.0.weight']
model.conv1[0].bias.data = state_dict['state_dict']['conv1.0.bias']
model.conv2[0].weight.data = state_dict['state_dict']['conv2.0.weight']
model.conv2[0].bias.data = state_dict['state_dict']['conv2.0.bias']

model.conv3[0].weight.data = state_dict['state_dict']['conv3.0.weight']
model.conv3[0].bias.data = state_dict['state_dict']['conv3.0.bias']
model.conv3_1[0].weight.data = state_dict['state_dict']['conv3_1.0.weight']
model.conv3_1[0].bias.data = state_dict['state_dict']['conv3_1.0.bias']

model.conv4[0].weight.data = state_dict['state_dict']['conv4.0.weight']
model.conv4[0].bias.data = state_dict['state_dict']['conv4.0.bias']
model.conv4_1[0].weight.data = state_dict['state_dict']['conv4_1.0.weight']
model.conv4_1[0].bias.data = state_dict['state_dict']['conv4_1.0.bias']

model.conv5[0].weight.data = state_dict['state_dict']['conv5.0.weight']
model.conv5[0].bias.data = state_dict['state_dict']['conv5.0.bias']
model.conv5_1[0].weight.data = state_dict['state_dict']['conv5_1.0.weight']
model.conv5_1[0].bias.data = state_dict['state_dict']['conv5_1.0.bias']

model.conv6[0].weight.data = state_dict['state_dict']['conv6.0.weight']
model.conv6[0].bias.data = state_dict['state_dict']['conv6.0.bias']
model.conv6_1[0].weight.data = state_dict['state_dict']['conv6_1.0.weight']
model.conv6_1[0].bias.data = state_dict['state_dict']['conv6_1.0.bias']
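
A less manual alternative, sketched here under the assumption that the shared layers have identical names and shapes: filter the checkpoint down to the entries that also exist in the model, then load them with strict=False so the fc layers (and the absent decoder keys) are simply ignored.

# Keep only checkpoint entries whose name and shape match our model
model_state = model.state_dict()
filtered = {k: v for k, v in state_dict['state_dict'].items()
            if k in model_state and v.shape == model_state[k].shape}
model.load_state_dict(filtered, strict=False)

Either way, the encoder ends up initialized from FlowNetS, and only the new fully connected layers still need to be trained from scratch.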

Summary

This post only offers one way of doing it; it is not necessarily the best, but it works for me for now.
