更改预训练权重的部分参数形状(训练自己数据集类别发生变化)

1.在使用一些开源的预训练权重的时候,会碰到类别不匹配的现象,因为预训练数据集的类别数量和我们进行微调的数据集中的类别数量往往不是相等的。

通过加载预训练权重并且更改预训练权重中的分类头的参数形状,来适应微调数据集的类别数

import torch
import numpy as np

# 指定旧模型权重文件的路径
old_pt_path = "/home/notebook/data/group/weicongcong/vit-pytorch/hugging_vit/pytorch_model.bin"
# 指定新模型权重文件的保存路径
new_pt_path = "/home/notebook/data/group/weicongcong/vit-pytorch/hugging_vit/new_pytorch_model.bin"

# 旧的类别数量(原模型的分类层输出大小)
old_cls_len = 1000
# 新的类别数量(修改后模型的分类层输出大小)
new_cls_len = 7

# 分类层输入的特征数量
in_channels = 768

# 加载旧模型的状态字典到 CPU 内存中
old_state_dict = torch.load(old_pt_path, map_location='cpu')
# 创建一个新的状态字典用于存放修改后的参数
new_state_dict = {}

# 遍历旧状态字典中的每个键值对
for key in old_state_dict:
    print(key)  # 打印当前处理的键名,通常用于调试
    if key == "classifier.weight":  # 检查是否为分类层的权重
        print(old_state_dict[key].shape)  # 打印旧权重的形状
        # 截取权重矩阵的前7行,使其与新类别数相匹配
        new_state_dict[key] = old_state_dict[key][:new_cls_len, :]
        print(new_state_dict[key].shape)  # 打印新权重的形状以确认更改
    elif key == "classifier.bias":  # 检查是否为分类层的偏置
        print(old_state_dict[key].shape)  # 打印旧偏置的形状
        # 截取偏置向量的前7个元素,使其与新类别数相匹配
        new_state_dict[key] = old_state_dict[key][:new_cls_len]
    else:
        # 对于其他不需要修改的参数,直接复制到新状态字典中
        new_state_dict[key] = old_state_dict[key]

# 将修改后的状态字典保存到新的文件路径
torch.save(new_state_dict, new_pt_path)

直接加载改变预训练权重的方法比在训练的时候加载模型的方法看起来更加简单直观

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值