迁移学习加载改进后的模型参数

该博客介绍了如何在YOLOv4算法基础上进行模型改进并利用迁移学习初始化参数。通过加载预训练权重文件,使用`torch.load()`将模型字典与预训练字典匹配,并根据形状进行筛选,确保只更新结构相同的参数。文中提供了`transfer_model`函数实现这一过程,并引用了其他开发者的工作作为参考。
摘要由CSDN通过智能技术生成

pytorch基于改进模型使用迁移学习加载参数

前言

由于发paper或多或少的需要一定创新性,本篇文章是基于YOLOv4算法想对模型结构进行改进,这里仅仅是做一个小样。调整主干模型并不对主干模型的输出进行改变,因此只需要将修改后模型的参数初始化。最后正常对模型进行训练即可。
在这里插入图片描述

本文也许并不一定可以帮到你,但是如果你也想使用迁移学习的方式去初始化模型的参数,那么我相信下面的文章或多或少对你有一些启发。
同时很感谢参考文章的两位作者。

正文

今天在考虑如何对YOLOv4模型进行改进如何考虑使用迁移学习的方式为没有变化的结构提供参数,目前自己大概了解到torch.load()方法是将文件加载成字典的格式。下面是Github上面一国内大神(Bubbliiiing)的yolov4代码pytorch版本进行迁移学习加载权重文件的方式。

    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location=device)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) ==  np.shape(v)}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

但是我对YOLOv4的模型进行改进,于是考虑简单的对网络结构进行改变,但是由于又想借助迁移学习的方式进行参数初始化,借鉴了https://blog.csdn.net/guyuealian/article/details/94181896
这一位大神的代码,实现了迁移学习初始化权重。

def transfer_model(pretrained_file, model):
    pretrained_dict = torch.load(pretrained_file)  # get pretrained dict
    model_dict = model.state_dict()  # get model dict
    # 在合并前(update),需要去除pretrained_dict一些不需要的参数
    pretrained_dict = transfer_state_dict(pretrained_dict, model_dict)
    model_dict.update(pretrained_dict)  # 更新(合并)模型的参数
    model.load_state_dict(model_dict)
    return model


def transfer_state_dict(pretrained_dict, model_dict):
    # state_dict2 = {k: v for k, v in save_model.items() if k in model_dict.keys()}
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k in model_dict.keys():
            if v.shape==model_dict[k].shape:
                state_dict[k] = v
            else:
                print(k,'shape dismatch')
        else:
            print("Missing key(s) in state_dict :{}".format(k))
    return state_dict
model = transfer_model("model_data/yolo4_voc_weights.pth",model)

参考文章
[1]:https://github.com/bubbliiiing/yolov4-pytorch
[2]: https://blog.csdn.net/guyuealian/article/details/94181896

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值