参数名的映射,小心使用strict=False

文章讲述了在PyTorch中加载vgg16预训练模型时遇到参数名不匹配的问题,通过设置`strict=False`避免严格检查,或者通过修改参数名映射来解决。作者提供了将参数名中.module.前缀删除的方法,以及手动创建参数名映射字典的方式,成功加载模型权重。
摘要由CSDN通过智能技术生成

从vgg16-397923af.pth里读取的数值应该和加载预训练模型后model.load_state_dict参数一致。
而我的不一致!
原因:在载入参数到模型键值的不匹配,所以使用了strict=False。
解决办法

  • 进行参数名的映射,将不匹配的参数名进行对应
  • 看到另一种方法——将即将要载入的参数中不匹配的键多余部分,‘module.’删除就可匹配【未尝试】
params = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()} 	#替换将要载入的参数的键的不匹配部分

#	进行参数名的映射
import numpy as np
import torch
import torchvision
from torchvision import models
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from src import fcn_resnet50, resnet50
from model_fcn8s import VGG, fcn_vgg16

pretrain_backbone = True
a = ["layer1.0.weight", "layer1.0.bias", "layer1.2.weight", "layer1.2.bias", "layer2.0.weight", "layer2.0.bias", "layer2.2.weight", "layer2.2.bias", "layer3.0.weight", "layer3.0.bias", "layer3.2.weight", "layer3.2.bias", "layer3.4.weight", "layer3.4.bias", "layer4.0.weight", "layer4.0.bias", "layer4.2.weight", "layer4.2.bias", "layer4.4.weight", "layer4.4.bias", "layer5.0.weight", "layer5.0.bias", "layer5.2.weight", "layer5.2.bias", "layer5.4.weight", "layer5.4.bias"]
b = ["features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias"]
struct = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]
backbone = VGG(num_classes=21, struct=struct)
if pretrain_backbone is True:
    weights_dict = torch.load("/home/hyq/hyq/projects/fcn/vgg16-397923af.pth")
    model_dict = {}
    param_mapping = dict(zip(b, a))
    for k, v in weights_dict.items():
        if k not in b:
            continue
        model_dict[param_mapping[k]] = v
    # backbone.load_state_dict(torch.load("/home/hyq/hyq/projects/fcn/vgg16-397923af.pth"), strict=False)
    backbone.load_state_dict(model_dict)
    for name, param in backbone.state_dict().items():
        print(f"{name}: {param}") 
        # print(name, end=' ')

在加载预训练模型处修改
在这里插入图片描述

从vgg16-397923af.pth里读取的数值应该和加载预训练模型后model.load_state_dict参数一致

【pytorch载入模型参数报错以及解决办法,小心使用strict=False】

在 Python 中,`json.loads()` 是将 JSON 格式的字符串转换为 Python 对象的函数。`strict=False` 参数是用来控制解析 JSON 字符串时是否严格按照 JSON 标准进行解析的。 当 `strict=False` 时,`json.loads()` 函数会容忍一些非标准的 JSON 格式,比如单引号代替双引号、不带引号的属性、尾部逗号等,这些在标准的 JSON 中是不被允许的。如果 JSON 字符串中有这些非标准的格式,`json.loads()` 函数会尝试进行容错处理,将它们转换为标准的 JSON 格式,然后再进行解析。 举个例子,假设有一个 JSON 字符串如下: ```json { 'name': 'Alice', 'age': 20, 'address': { 'city': 'Beijing', 'country': 'China', } } ``` 这个 JSON 字符串中,属性使用的是单引号而不是双引号,并且 `address` 对象末尾有一个逗号。如果直接使用 `json.loads()` 进行解析,会抛出 `json.decoder.JSONDecodeError` 异常,因为这不是合法的 JSON 格式。但是如果加上 `strict=False` 参数,就可以容忍这些非标准的格式,比如这样: ```python import json json_string = ''' { 'name': 'Alice', 'age': 20, 'address': { 'city': 'Beijing', 'country': 'China', } } ''' data = json.loads(json_string, strict=False) print(data) ``` 输出结果为: ```python {'name': 'Alice', 'age': 20, 'address': {'city': 'Beijing', 'country': 'China'}} ``` 从输出结果可以看出,`json.loads()` 函数已经将原来的 JSON 字符串转换为了一个 Python 字典对象,并且自动将单引号转换为了双引号,将末尾的逗号去掉了,同时忽略了属性不带引号的问题。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值