如何在Pytorch中载入部分权重

如何在Pytorch中载入部分权重

很多时候,比如我们不想训练模型,想用预训练模型来进行测试,或者加载预训练模型来训练。但是预训练模型中网络权重已经训练好了,是一个整体。比如resnet网络默认输出的类别是1000类,但是我们现在的数据是5类,那怎么才能用上这个预训练模型呢?
答案是:我们可以载入部分权重

方法有两种

方法一

我们知道网络是由很多层堆叠起来的,默认是1000类的残差网络,前面的多层卷积层是不用修改的,但是最后一层的全连接层不满足要求,需要根据自己类别修改。
在初始话的时候不传入类别参数,直接修改全连接层的结构

    net = resnet34()
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)
    print('net:', net)

在这里插入图片描述
可以看出,最后的全连接层已经修改成5了。
其实这里我有个疑问,这样只是修改了结构,但是权重不就还是之前的权重吗?

方法二

在模型初始化话的时候,传入自己的类别,这样网络的结构肯定是没问题的,最后全连接层输出是5。但是权重参数是按照1000训练的,加载权重时不要加载全连接层相关的参数即可。

    net = resnet34(num_classes=5)
    pre_weights = torch.load(model_weight_path, map_location=device)

    del_key = []
    for key, _ in pre_weights.items():
        if "fc" in key:
            del_key.append(key)
            weight_t = net.state_dict()[key].numpy()
            print(key, ":", weight_t)
    
    for key in del_key:
        del pre_weights[key]
    
    missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
    print("[missing_keys]:", *missing_keys, sep="\n")
    print("[unexpected_keys]:", *unexpected_keys, sep="\n")

完整代码如下

import os
import torch
import torch.nn as nn
from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    # option1
    net = resnet34()
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)



    # option2
    # net = resnet34(num_classes=5)
    # pre_weights = torch.load(model_weight_path, map_location=device)

    # del_key = []
    # for key, _ in pre_weights.items():
    #     if "fc" in key:
    #         del_key.append(key)
    #         weight_t = net.state_dict()[key].numpy()
    #         print(key, ":", weight_t)
    
    # for key in del_key:
    #     del pre_weights[key]
    
    # missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
    # print("[missing_keys]:", *missing_keys, sep="\n")
    # print("[unexpected_keys]:", *unexpected_keys, sep="\n")
    

    

if __name__ == '__main__':
    main()

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值