参考博客:
https://blog.csdn.net/weixin_40522801/article/details/106563354
https://blog.csdn.net/yangwangnndd/article/details/100207686
函数定义:
load_state_dict(state_dict, strict=True)
作用:
使用 state_dict 反序列化模型参数字典。用来加载模型参数。将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中。
概况:给模型对象加载训练好的模型参数,即加载模型参数
关于state_dict:
在PyTorch中,一个torch.nn.Module模型中的可学习参数(比如weights和biases),模型的参数通过model.parameters()获取。而state_dict就是一个简单的Python dictionary,其功能是将每层与层的参数张量之间一一映射。
注意,只有包含了可学习参数(卷积层、线性层等)的层和已注册的命令(registered buffers,比如batchnorm的running_mean)才有模型的state_dict入口。优化方法目标(torch.optim)也有state_dict,其中包含的是关于优化器状态的信息和使用到的超参数。
因为state_dict目标是Python dictionaries,所以它们可以很轻松地实现保存、更新、变化和再存储,从而给PyTorch模型和优化器增加了大量的模块化(modularity)。
使用示例:
model.load_state_dict(torch.load('pose_dekr_hrnetw32_coco.pth'), strict=True)
官方函数说明:
Copies parameters and buffers from state_dict into this module
and its descendants. If strict is True, then the keys of
state_dict must exactly match the keys returned by this
module’s state_dict() function.
从函数接收的参数state_dict中将参数和缓冲拷贝到当前这个模块及其子模块中.
如果函数接受的参数strict是True,那么state_dict的关键字必须确切地严格地和
该模块的state_dict()函数返回的关键字相匹配.
Parameters 参数:
state_dict (dict) – a dict containing parameters and persistent buffers.
state_dict (字典类型) – 一个包含参数和持续性缓冲的字典
往往是pytorch模型pth文件
strict (布尔类型, 可选) – 该参数用来指明是否需要强制严格匹配,
即:state_dict中的关键字是否需要和该模块的state_dict()方法返回的关键字强制严格匹配.默认值是True
Returns 返回:
返回类型:NamedTuple with missing_keys and unexpected_keys fields
missing_keys is a list of str containing the missing keys
missing_keys是一个字符串的列表,该列表包含了所有缺失的关键字.
unexpected_keys is a list of str containing the unexpected keys
unexpected_keys是一个字符串的列表,该列表包含了意料之外的关键字,