pytorch:如何修改加载了预训练权重的模型的输入或输出--(原权重文件修改参数)

在使用pytorch的过程中,我们往往会使用官方发布的预训练模型,并在此基础上训练自己的模型。为了适配训练数据,有时候需要局部修改这类预训练模型的结构,本文将分别以修改输入的通道数和输出的分类数为例,讲解一种通用的方法来修练模型的结构。

加载模型

在修改之前需要加载预训练模型,这里以mobilenet v2为例

import torchvision.models as models
model = models.mobilenet_v2(pretrained=True)

修改模型结构

修改之前需要查看模型结构

print(model)

则可以看到一长串的模型输入,这里因为篇幅原因只截了开头部分和结尾部分
在这里插入图片描述
仔细观察输出的模型结构,卷积层(特别是括号中的features,classifier,(0) 等标志性词可以得知模型的第一层为:

model.features[0][0] = Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

分类层为

model.classifier[1] = Linear(in_features=1280, out_features=1000, bias=True)

通过这些信息接下来就可以修改输入/输出了

修改模型输入

#输入为单通道

model.features[0][0] = Conv2d(1 ,32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

#修改预训练模型权重的结构,使得模型可以使用修改后的预训练模型权重

#加载预训练模型
pre_trained_model = models.mobilenet_v2(pretrained=True)
#获取预训练权重文件的字典
pretrained_dict = pre_trained_model.state_dict()
#打印权重信息
print(pretrained_dict.items())
'''
打印显示为:
dict_items([('features.0.weight', tensor([[[[-2.8656e-03,  4.1653e-02,  5.7146e-02,  ...,  5.2015e-03,
           -5.7198e-03, -2.3688e-02],...
这里可以看到第一层的key为‘features.0.weight’,接下来就可以通过这个名称访问pretrained_dict中对应的权重
'''
#获取第一层权重
layer1 = pretrained_dict['features.0.0.weight']
#创建一个新的张量,这个张量后面将替代pretrain_dict中的第一层,以适应修改为单通道的模型
new = torch.zeros(32,1,3, 3)
#这里修改第一层
for i,output_channel in enumerate(layer1):
	# Grey = 0.299R + 0.587G + 0.114B, 这个公式参考了RGB图转灰度图的方式
    new[i] = 0.299 * output_channel[0] + 0.587 * output_channel[1] + 0.114 * output_channel[2]
#现在第一层的shape为(32,1,3,3)了
pretrained_dict['features.0.0.weight'] = new 
#修改模型结构
model.features[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
model.load_state_dict(pretrained_dict)

修改模型输出

这里修改输出的方式不像修改输入这么繁琐

fc_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(fc_features, 2)
  • 7
    点赞
  • 62
    收藏
    觉得还不错? 一键收藏
  • 11
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值