pth和pth.tar???checkpoint的组成???state_dict???key???unpected key??? missing key???

最近复现一篇论文,找作者要了训练的checkpoint,他给了我一个网站,里面都是tar.pth的文件。
从我学深度学习开始见过的checkpoint全都是pth后缀的,导致我这次整尬了,贼尴尬,这个tar.pth文件其实也是checkpoint的一种。
它包含的信息更多罢了,我没搞明白,我就问作者,然后作者就把我当白痴了,c了,故写下此篇笔记,帮助更多的人少走弯路55555

一. pth和pth.tar的区别

.pth.pth.tar 文件在 PyTorch 中通常是用来保存模型状态的,这些状态包括模型的参数(权重和偏置)。它们之间的区别主要在于命名惯例,而不是功能:

  1. .pth 文件:这是最常见的文件后缀,通常用来保存模型的状态字典(state_dict),也就是模型的权重和偏置参数。

  2. .pth.tar 文件:这个后缀通常用来表示该文件是一个打包文件(虽然实际上并不需要解压)。它可以包含除了模型参数外的其他信息,比如训练过程中保存的元数据(如训练的 epoch 数、模型架构信息等)。

二. 模型文件可保存的信息:

至于模型文件中可以保存的信息,主要有以下几种:

  1. 模型结构(Architecture):保存整个模型的结构信息,这样在加载模型时可以直接恢复模型,而不需要重新定义模型结构。这通常用 .pth 文件或 .pth.tar 文件来实现,但要注意,保存整个模型结构会使文件体积较大。

  2. 模型权重(Weights):这是最常见的保存内容,保存模型的参数。通常使用 model.state_dict() 来提取和保存。

  3. 优化器状态(Optimizer State):在继续训练时,保存优化器的状态(如动量、学习率等)是非常有用的。它可以帮助在训练中断后继续之前的训练过程。

  4. 其他元数据(Metadata):比如训练的 epoch 数、最优精度等,可以帮助记录训练的进展和结果。

总结来说,.pth.pth.tar 文件在 PyTorch 中本质上都是用于保存模型的,只不过 .pth.tar 有时包含更多元数据,适用于更复杂的训练和保存需求。

三. 什么是model.state_dict包含什么内容?

准确地说,状态字典(state_dict)主要保存模型中所有可学习参数,即那些在训练过程中通过反向传播算法被更新的参数,如权重和偏置。同时,状态字典也会保存某些非可学习参数,这些参数不会通过反向传播进行更新,但在推理和模型计算过程中依然起到关键作用。例如:

  1. 可学习参数

    • 权重(Weights):连接层之间的权值矩阵。
    • 偏置(Biases):各层的偏置项。
    • BatchNorm 层中的可学习参数:如缩放因子(weight)和偏移因子(bias)。
  2. 非可学习参数

    • BatchNorm 层中的统计参数:如 running_meanrunning_var,这些参数是在训练过程中累计得到的,并用于测试时的归一化操作,但不会被优化器更新。

此外,某些模型中还可能包含其他与模型结构或运行有关的参数,这些参数也会被保存到状态字典中,但这些参数通常也不会在训练过程中被更新。

总结

状态字典确实会保存所有可学习的参数和某些非可学习但在模型运行中需要使用的参数。因此,状态字典是全面反映模型当前状态的关键部分,包含了模型在推理时所需的全部信息。

四.key是啥?missing key和unpected key是啥意思?如何改正?

你遇到的错误信息主要涉及两个方面:unexpected keymissing keys。它们都和加载模型的 state_dict 有关。
前提:
key 是一个字符串,表示模型中的某个参数的位置或名称。例如,layer1.0.conv1.weight 表示模型中 layer1 下第一个卷积层(conv1)的权重参数。每个 key 对应的参数是模型计算过程中的一个最小单位,但从概念上讲,这些 key 并不是参数本身,而是参数的**“名称”“地址”**。
value :key 对应的值是一个张量,存储了实际的参数数据,如权重矩阵或偏置向量。这个张量是通过训练学习得到的。key 对应的张量(value)才是真正的模型参数。它们是经过反向传播更新的数值。

1. unexpected keymissing keys 的含义

  • unexpected key in source state_dict:当你加载模型的 state_dict 时,PyTorch 发现了你加载的 state_dict 中存在一些模型当前结构中不需要的参数。这意味着这些参数在你定义的模型中没有对应的部分,可能是因为你加载的 state_dict 是针对不同的模型架构或版本保存的。

  • missing keys in source state_dict:这个错误说明在你加载的 state_dict 中,缺少一些当前模型架构所需要的参数。也就是说,当前模型中定义了一些层或参数,但是在加载的 state_dict 中没有找到对应的权重或偏置。这通常是因为你加载的 state_dict 是为其他模型或模型的不同版本保存的。

2. 和 state_dict 的关系

state_dict 是一个保存了模型所有参数的字典,包括可训练参数(如权重和偏置)以及某些非可训练参数(如 BatchNorm 的统计参数)。当你使用 model.load_state_dict() 加载这个字典时,PyTorch 会检查字典中的键(通常是参数的名称)是否和模型中的参数对应。如果字典中的键在模型中找不到对应的参数位置,或者模型中定义的参数在字典中找不到,就会产生 unexpected keymissing keys 错误。

3. 关于 layer 的问题

layer 相关的 key 在 missing keys 中,意味着这些层对应的参数(如权重和偏置)没有成功从 state_dict 中加载到你的模型中。这可能是因为:

  • 加载的 state_dict 是针对不同版本的模型(如修改了层的名称或结构)。
  • 定义的模型结构与保存 state_dict 的模型结构不完全匹配(例如,层的数量或名称不同)。

4. 解决方法

如果确定模型架构和 state_dict 是匹配的,可能的问题是:

  • 手动调整模型的架构:根据 state_dict 中的键名,调整你模型中对应层的名称,以匹配 state_dict 的结构。

  • 严格模式关闭:在加载 state_dict 时,你可以通过将 strict=False 传递给 load_state_dict 方法,以忽略这些不匹配的键:

    model.load_state_dict(checkpoint['state_dict'], strict=False)
    

    这会加载所有可以匹配的参数,并忽略不匹配的部分,但你需要确认这样做不会影响模型的性能。

  • 检查预训练模型的来源:如果你使用的预训练模型与当前模型架构不一致,可以考虑找到与当前架构更匹配的预训练模型。

总结

missing keys 错误表示模型结构中需要的参数在加载的 state_dict 中找不到,导致对应的层没有成功加载参数。这通常是由于模型结构和 state_dict 不匹配引起的。你可以通过调整模型架构或加载方式来解决这个问题。

问题:

  1. unexpected key in source state_dict: model, last_iter …
  2. missing keys in source state_dict: conv1.weight, bn1.weight, bn1.bias, bn1.running_mean, bn1.running_var, …

###补充: missing keys的原因!!!!!!!!!!!!!!!!!!!
1.结构不匹配
2.pth和pth.tar不一样
3.需要的key名字不一样,需要更换key的名字比如:

import torch


model = torch.load("/home/guo/gv-benchmark/checkpoint/INTERN-mnb4-Up-G-dbn-cb11deba1.pth.tar", "cpu")


model_keys = list(model["state_dict"].keys())


for key in model_keys:
    model['state_dict']['model.' + key] = model['state_dict'][key].clone()
    del model['state_dict'][key]


torch.save(model, "/home/guo/gv-benchmark/checkpoint/INTERN-mnb4-Up-G-dbn-cb11deba1-new.pth.tar")

四. Conclusion

这次遇到的问题,总体来讲就是我在windows系统下载pth.tar的文件后,我直接解压,第一次解压时估计用的别的解压软件,解压出了pth文件,但是要知道pth.tar文件就是一个完整的权重文件了,你解压以后torch也没法识别了呀,然后导致了第 三 节的问题。
导致此问题的原因是:我把pth文件从torch.load导入,这个时候他按照pth文件的代码去导入的,但是因为我是pth.tar解压成的pth,所以这个pth实际上是包含了pth.tar中额外的一些例如模型架构的数据的,所以就导致,出现unpected key…
另外,由于pth文件的解压后出现key肯定也不完整了,出现了点错误呗,导致missing key啥的。

最后,我就直接用pth.tar文件作为预训练权重,torch.load直接load pth.tar就啥问题都没了。就这么简单55555555555

仅以此篇文章记录checkpoint的一些小知识
fighting
G_PP

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值