detectron2下预训练权重的替换特征提取网络

我需要将预训练权重的ResNet50网络替换成Mobilenetv2网络
1.先通过detectron2自己的权重转换代码

#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.

import pickle as pkl
import sys

import torch
import torchvision.models as models
"""
Usage:
  # download one of the ResNet{18,34,50,101,152} models from torchvision:
  wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth
  # run the conversion
  ./convert-torchvision-to-d2.py r50.pth r50.pkl
  # Then, use r50.pkl with the following changes in config:
MODEL:
  WEIGHTS: "/path/to/r50.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  RESNETS:
    DEPTH: 50
    STRIDE_IN_1X1: False
INPUT:
  FORMAT: "RGB"
"""

if __name__ == "__main__":
    # input = sys.argv[1]
    #
    # obj = torch.load(input, map_location="cpu")
    #
    # newmodel = {}
    # for k in list(obj.keys()):
    #     old_k = k
    #     if "layer" not in k:
    #         k = "stem." + k
    #     for t in [1, 2, 3, 4]:
    #         k = k.replace("layer{}".format(t), "res{}".format(t + 1))
    #     for t in [1, 2, 3]:
    #         k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
    #     k = k.replace("downsample.0", "shortcut")
    #     k = k.replace("downsample.1", "shortcut.norm")
    #     print(old_k, "->", k)
    #     newmodel[k] = obj.pop(old_k).detach().numpy()
    #
    # res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True}
    #
    # with open(sys.argv[2], "wb") as f:
    #     pkl.dump(res, f)
    # if obj:
    #     print("Unconverted keys:", obj.keys())

    ####################################
    input = sys.argv[1]

    obj = torch.load(input, map_location="cpu")

    newmodel = {}
    for k in list(obj.keys()):
        old_k = k
        k = "backbone." + k
        print(old_k, "->", k)
        newmodel[k] = obj.pop(old_k).detach().numpy()

    res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True}

    torch.save(newmodel, 'mnv2.pth')
    with open(sys.argv[2], "wb") as f:
        pkl.dump(res, f)
    if obj:
        print("Unconverted keys:", obj.keys())

2.再将权重中的骨干给替换掉

import collections
import torch

decode_path = r'preweights/fastinst_R50_ppm-fpn_x1_576_34.9.pth'
mnv2_path = r'preweights/mnv2.pth'
# 载入模型
decode = torch.load(decode_path, map_location=torch.device('cpu'))
mnv2 = torch.load(mnv2_path) 

# 可以看到模型的类型是字典
print(type(mnv2))
# 模型参数的数量
print(len(mnv2))

mydict={}

for k in decode['model'].keys():
    if "backbone" not in k:
        mydict[k] = decode['model'][k]

for k in mnv2.keys():
    mydict[k] = mnv2[k]

for k in mydict.keys():
    print(k)  # 可以看到每套参数的key,也就是该套参数的名字
print("替换model")
decode["model"] = mydict
torch.save(decode,"new_fastInst_mnv2.pth")

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值