1.在使用一些开源的预训练权重的时候,会碰到类别不匹配的现象,因为预训练数据集的类别数量和我们进行微调的数据集中的类别数量往往不是相等的。
通过加载预训练权重并且更改预训练权重中的分类头的参数形状,来适应微调数据集的类别数
import torch
import numpy as np
# 指定旧模型权重文件的路径
old_pt_path = "/home/notebook/data/group/weicongcong/vit-pytorch/hugging_vit/pytorch_model.bin"
# 指定新模型权重文件的保存路径
new_pt_path = "/home/notebook/data/group/weicongcong/vit-pytorch/hugging_vit/new_pytorch_model.bin"
# 旧的类别数量(原模型的分类层输出大小)
old_cls_len = 1000
# 新的类别数量(修改后模型的分类层输出大小)
new_cls_len = 7
# 分类层输入的特征数量
in_channels = 768
# 加载旧模型的状态字典到 CPU 内存中
old_state_dict = torch.load(old_pt_path, map_location='cpu')
# 创建一个新的状态字典用于存放修改后的参数
new_state_dict = {}
# 遍历旧状态字典中的每个键值对
for key in old_state_dict:
print(key) # 打印当前处理的键名,通常用于调试
if key == "classifier.weight": # 检查是否为分类层的权重
print(old_state_dict[key].shape) # 打印旧权重的形状
# 截取权重矩阵的前7行,使其与新类别数相匹配
new_state_dict[key] = old_state_dict[key][:new_cls_len, :]
print(new_state_dict[key].shape) # 打印新权重的形状以确认更改
elif key == "classifier.bias": # 检查是否为分类层的偏置
print(old_state_dict[key].shape) # 打印旧偏置的形状
# 截取偏置向量的前7个元素,使其与新类别数相匹配
new_state_dict[key] = old_state_dict[key][:new_cls_len]
else:
# 对于其他不需要修改的参数,直接复制到新状态字典中
new_state_dict[key] = old_state_dict[key]
# 将修改后的状态字典保存到新的文件路径
torch.save(new_state_dict, new_pt_path)
直接加载改变预训练权重的方法比在训练的时候加载模型的方法看起来更加简单直观