pytorch保存模型后加载模型遇到的大坑

保存模型

  • 保存模型的方式主要有两种, 一种是全量保存,另一种是只保存网络结构的参数(注意,不保存网络结构,只保存参数,所以在加载模型的时候需要先设置好一个模型网络)
#1.保存整个网络
torch.save(model_object, 'model.pth')
#1.1加载参数
model = torch.load('model.pth')


#2.保存参数
torch.save(model_object.state_dict(), 'params.pth')
#2.1加载模型
model_object=model()
model_object.load_state_dict(torch.load('params.pth'))
  • 针对第二种方法,当你使用上述代码来调用保存的模型时,你会发现,模型效果非常差!离谱!明明保存好模型了,调用的代码也不报错,也能顺利运行,可为什么效果这么差呢?!我为此反复debug,后来不经意间百度了一下为何pytorch保存的模型测试效果这么差,才发现原来这是一个巨坑。
  • 下面是借鉴的网上的资料,期间看了很多博客,发现都是模棱两可,便结合很多资料进行总结,得到了下述结果。
  • 首先纠正上述代码,再讲明其中缘由:
#2.保存参数
torch.save(model_object.state_dict(), 'params.pth')
#2.1加载模型
model_object=model()
model_object.load_state_dict(torch.load('params.pth'))
model.eval()#制定model.eval()固定dropout和BN层。
  • 效果对比:
    之前:
    在这里插入图片描述
    例子是一个图像分割网络,左边是训练集,中间是训练集标签,右边是调用模型的效果,可见非常差,和标签差了太多!
    加上model.eval()之后:
    在这里插入图片描述
    效果好了很多,这才是正常的测试效果。

model.eval()

  • 仅仅多了如此一行,为何有奇效?

pytorch中model.eval()的作用
问题描述:
torch.onnx.export()导出onnx模型后,利用onnxruntime加载onnx模型后,其输出结果与原始.pth模型的输出结果之间存在很大的差距;通过拆分网络结构,定位到nn.BatchNorm2d()层导致;
Batch Normalization和Dropout
Batch Normalization
其作用对网络中间的每层进行归一化处理,并且使用变换重构(Batch Normalization Transform)保证每层提取的特征分布不会被破坏。训练时是针对每个mini-batch的,但是测试是针对单张图片的,即不存在batch的概念。由于网络训练完成后参数是固定的,每个batch的均值和方差是不变的,因此直接结算所有batch的均值和方差。所有Batch Normalization的训练和测试时的操作不同。
Dropout
其作用克服Overfitting,在每个训练批次中,通过忽略一半的特征检测器,可以明显的减少过拟合现象。
model.train()和model.eval()
train()
启用 BatchNormalization 和 Dropout
eval()
不启用 BatchNormalization 和 Dropout,保证BN和dropout不发生变化,pytorch框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层影响结果。
问题解决办法
在利用原始.pth模型进行前向推理之前,一定要先进行model.eval()操作,不启用 BatchNormalization 和 Dropout。

  • 100
    点赞
  • 178
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 26
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CtrlZ1

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值