解决torch.load中Unexpected key(s) in state_dict的报错问题
问题发现
今天在使用强化学习跑实验的时候,使用到了torch.load_state_dict()
这个函数,从已经保存的pth文件中加载模型权重。但是加载的过程中,出现了Unexpected key(s) in state_dict
的错误。
问题排查
- 我看了看写的代码,发现模型保存和模型加载用到的都是同一份模型文件,按理说应该不会存在问题。
- 再仔细钻研了以下,发现
Unexpected key(s) in state_dict
错误出现的地方是模型中一些常量
的保存,比如说nn.Parameter()的保存。查了下chatgpt,找到了解决方法:只需要在加载文件的时候适用Strict=False
这一行代码使模型能允许加载模型时部分键不匹配。