🤵 Author :Horizon John
✨ 编程技巧篇:各种操作小结
🎇 机器视觉篇:会变魔术 OpenCV
💥 深度学习篇:简单入门 PyTorch
🏆 神经网络篇:经典网络模型
💻 算法篇:再忙也别忘了 LeetCode
错误提示
RuntimeError: Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var"
后面接着一大堆网络结构相关的东西
错误原因
load_state_dict
中 strict 参数默认是 True :
表示预训练模型的层和自己定义的网络结构层严格对应相等(如层名和维度)
所以当我们修改了网络结构后,如果strict为True的时候就会报错
def load_state_dict(self, state_dict, strict=True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
Arguments:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
解决方案
将 load_state_dict
中 strict 参数设置为 False
修改前:
model.load_state_dict(torch.load(model_path))
修改后:
model.load_state_dict(torch.load(model_path), False)
🈺 喜欢的 留个 关注
、 加 点赞
哦 ~