从vgg16-397923af.pth里读取的数值应该和加载预训练模型后model.load_state_dict参数一致。
而我的不一致!
原因:在载入参数到模型键值的不匹配,所以使用了strict=False。
解决办法:
- 进行参数名的映射,将不匹配的参数名进行对应
- 看到另一种方法——将即将要载入的参数中不匹配的键多余部分,‘module.’删除就可匹配【未尝试】
params = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()} #替换将要载入的参数的键的不匹配部分
# 进行参数名的映射
import numpy as np
import torch
import torchvision
from torchvision import models
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from src import fcn_resnet50, resnet50
from model_fcn8s import VGG, fcn_vgg16
pretrain_backbone = True
a = ["layer1.0.weight", "layer1.0.bias", "layer1.2.weight", "layer1.2.bias", "layer2.0.weight", "layer2.0.bias", "layer2.2.weight", "layer2.2.bias", "layer3.0.weight", "layer3.0.bias", "layer3.2.weight", "layer3.2.bias", "layer3.4.weight", "layer3.4.bias", "layer4.0.weight", "layer4.0.bias", "layer4.2.weight", "layer4.2.bias", "layer4.4.weight", "layer4.4.bias", "layer5.0.weight", "layer5.0.bias", "layer5.2.weight", "layer5.2.bias", "layer5.4.weight", "layer5.4.bias"]
b = ["features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias"]
struct = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]
backbone = VGG(num_classes=21, struct=struct)
if pretrain_backbone is True:
weights_dict = torch.load("/home/hyq/hyq/projects/fcn/vgg16-397923af.pth")
model_dict = {}
param_mapping = dict(zip(b, a))
for k, v in weights_dict.items():
if k not in b:
continue
model_dict[param_mapping[k]] = v
# backbone.load_state_dict(torch.load("/home/hyq/hyq/projects/fcn/vgg16-397923af.pth"), strict=False)
backbone.load_state_dict(model_dict)
for name, param in backbone.state_dict().items():
print(f"{name}: {param}")
# print(name, end=' ')
在加载预训练模型处修改
从vgg16-397923af.pth里读取的数值应该和加载预训练模型后model.load_state_dict参数一致