我需要将预训练权重中的 ResNet50 骨干网络替换为 MobileNetV2 网络。
1. 先用 detectron2 自带的权重转换脚本（convert-torchvision-to-d2.py，已改写）将 torchvision 的 MobileNetV2 权重转换为带 `backbone.` 前缀的格式：
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
import pickle as pkl
import sys
import torch
import torchvision.models as models
"""
Usage:
# download one of the ResNet{18,34,50,101,152} models from torchvision:
wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth
# run the conversion
./convert-torchvision-to-d2.py r50.pth r50.pkl
# Then, use r50.pkl with the following changes in config:
MODEL:
WEIGHTS: "/path/to/r50.pkl"
PIXEL_MEAN: [123.675, 116.280, 103.530]
PIXEL_STD: [58.395, 57.120, 57.375]
RESNETS:
DEPTH: 50
STRIDE_IN_1X1: False
INPUT:
FORMAT: "RGB"
"""
if __name__ == "__main__":
# input = sys.argv[1]
#
# obj = torch.load(input, map_location="cpu")
#
# newmodel = {}
# for k in list(obj.keys()):
# old_k = k
# if "layer" not in k:
# k = "stem." + k
# for t in [1, 2, 3, 4]:
# k = k.replace("layer{}".format(t), "res{}".format(t + 1))
# for t in [1, 2, 3]:
# k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
# k = k.replace("downsample.0", "shortcut")
# k = k.replace("downsample.1", "shortcut.norm")
# print(old_k, "->", k)
# newmodel[k] = obj.pop(old_k).detach().numpy()
#
# res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True}
#
# with open(sys.argv[2], "wb") as f:
# pkl.dump(res, f)
# if obj:
# print("Unconverted keys:", obj.keys())
####################################
input = sys.argv[1]
obj = torch.load(input, map_location="cpu")
newmodel = {}
for k in list(obj.keys()):
old_k = k
k = "backbone." + k
print(old_k, "->", k)
newmodel[k] = obj.pop(old_k).detach().numpy()
res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True}
torch.save(newmodel, 'mnv2.pth')
with open(sys.argv[2], "wb") as f:
pkl.dump(res, f)
if obj:
print("Unconverted keys:", obj.keys())
2. 再将 FastInst 预训练权重中的骨干网络参数替换为第 1 步转换得到的 MobileNetV2 参数（需先把第 1 步生成的 mnv2.pth 放到 preweights/ 目录下）：
import collections
import torch
def merge_backbone(decode, mnv2, prefix="backbone."):
    """Return decode["model"] with its backbone weights replaced by ``mnv2``.

    Args:
        decode: checkpoint dict whose ``"model"`` entry maps parameter names
            to weights (here: the pretrained FastInst checkpoint).
        mnv2: replacement backbone weights. Every key must already carry
            ``prefix``, as produced by the conversion step. Values may be
            tensors or numpy arrays.
        prefix: key prefix that identifies backbone parameters.

    Returns:
        A new dict holding every non-backbone entry of decode["model"],
        followed by all entries of ``mnv2`` converted to tensors. Neither
        input is modified.

    Raises:
        ValueError: if any key of ``mnv2`` does not start with ``prefix``.
    """
    bad = [k for k in mnv2 if not k.startswith(prefix)]
    if bad:
        raise ValueError(
            f"backbone keys must start with {prefix!r}, got e.g. {bad[:3]}"
        )
    # Match on the key prefix rather than a substring, so that non-backbone
    # parameters whose names merely contain "backbone" are kept.
    merged = {
        k: v for k, v in decode["model"].items() if not k.startswith(prefix)
    }
    # as_tensor accepts both tensors and numpy arrays (in case mnv2.pth was
    # written with numpy values), keeping the checkpoint homogeneous.
    merged.update((k, torch.as_tensor(v)) for k, v in mnv2.items())
    return merged


if __name__ == "__main__":
    decode_path = r'preweights/fastinst_R50_ppm-fpn_x1_576_34.9.pth'
    mnv2_path = r'preweights/mnv2.pth'
    out_path = "new_fastInst_mnv2.pth"

    # Load both checkpoints onto the CPU so no GPU is required.
    decode = torch.load(decode_path, map_location=torch.device('cpu'))
    mnv2 = torch.load(mnv2_path, map_location=torch.device('cpu'))
    # The converted backbone is expected to be a plain dict.
    print(type(mnv2))
    print(len(mnv2))  # number of backbone parameter entries

    mydict = merge_backbone(decode, mnv2)
    for k in mydict:
        print(k)  # name of every parameter in the merged model
    print("替换model")
    # Only the "model" entry is replaced; other checkpoint entries are kept.
    decode["model"] = mydict
    torch.save(decode, out_path)