0.前言:
在深度学习中,官方提供的模型预训练权重文件对于我们的训练有相当大的作用。因为这些权重文件都是官方使用大型数据集在高性能计算机上训练得到的,我们自己很难在大型数据集上训练出自己模型的预训练权重文件,再迁移到自己的数据集上。在官方的模型上,我修改了部分结构——添加了部分操作,也就是说修改后的模型比官方提供的模型更大了。在这里将官方和自己模型相同的部分,加载上官方的权重文件值。
1.生成自己模型权重文件:
这个就很简单啦!在进行模型训练时,训练一轮就能得到模型随机初始化的模型权重文件。这里取名my_weigths.pth
2.模型转化:
权重文件是一个按“字典”形式存储的文件,理解了“字典”这个类型后,就好办多了,首先将两个权重文件读取进来,之后根据自己的权重文件中的“键”将官方提供的权重文件中的相同“键”对应的值赋值过来。
首先看一下自己权重文件的共有多少“键”,分别是什么。
代码:
在这里插入代码片
import torch
my_weights = torch.load(r'E:\Python\temp\my_weights.pth')
print('len=', len(my_weights.keys()))
print('keys():', my_weights.keys())
结果如下,共有503个键值对,部分“键”的名称如下:
将官方的权重文件共有“键”个数,“键”的名称输出
代码:
import torch
Megvi_s = torch.load(r'E:\Python\temp\Megvi_s.pth')
print('len= ', len(Megvi_s['model'].keys()))
print(Megvi_s['model'].keys())
结果如下,共有462个键值对,部分键名称如下,很明显官方的模型键值对只有462个,而自己的却有503个,那就把这462个键值对加载过来用用试试。
import torch
my_weights = torch.load(r'E:\Python\temp\my_weights.pth')
Megvi_s = torch.load(r'E:\Python\temp\Megvi_s.pth')
for item in Megvi_s['model'].keys():
print(item)
if item in my_weights.keys():
my_weights[item] = Megvi_s['model'][item]
# 保存权重文件
torch.save(my_weights,'E:\\Python\\temp\\new_weights.pth')
结果