深度学习中修改了模型结构后,将官方的预训练权重文件转化为自己模型的权重文件

0.前言:

在深度学习中,官方提供的模型预训练权重文件对于我们的训练有相当大的作用。因为这些权重文件都是官方使用大型数据集在高性能计算机上训练得到的,我们自己很难在大型数据集上训练出自己模型的预训练权重文件,再迁移到自己的数据集上。在官方的模型上,我修改了部分结构——添加了部分操作,也就是说修改后的模型比官方提供的模型更大了。在这里将官方和自己模型相同的部分,加载上官方的权重文件值。

1.生成自己模型权重文件:

这个就很简单啦!在进行模型训练时,训练一轮就能得到模型随机初始化的模型权重文件。这里取名my_weigths.pth
在这里插入图片描述

2.模型转化:

权重文件是一个按“字典”形式存储的文件,理解了“字典”这个类型后,就好办多了,首先将两个权重文件读取进来,之后根据自己的权重文件中的“键”将官方提供的权重文件中的相同“键”对应的值赋值过来。
首先看一下自己权重文件的共有多少“键”,分别是什么。
代码:

在这里插入代码片
import torch

my_weights = torch.load(r'E:\Python\temp\my_weights.pth')
   
print('len=', len(my_weights.keys()))          
print('keys():', my_weights.keys())             

结果如下,共有503个键值对,部分“键”的名称如下:
在这里插入图片描述
将官方的权重文件共有“键”个数,“键”的名称输出
代码:

import torch

Megvi_s = torch.load(r'E:\Python\temp\Megvi_s.pth')
  
print('len= ', len(Megvi_s['model'].keys()))        

print(Megvi_s['model'].keys())                           


结果如下,共有462个键值对,部分键名称如下,很明显官方的模型键值对只有462个,而自己的却有503个,那就把这462个键值对加载过来用用试试。
在这里插入图片描述

import torch

my_weights = torch.load(r'E:\Python\temp\my_weights.pth')
Megvi_s = torch.load(r'E:\Python\temp\Megvi_s.pth')

for item in Megvi_s['model'].keys():
    print(item)
    if item in my_weights.keys():
        my_weights[item] = Megvi_s['model'][item]


# 保存权重文件
torch.save(my_weights,'E:\\Python\\temp\\new_weights.pth')

结果
在这里插入图片描述

评论 15
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值