深度学习模型tensor维度对不上怎么办

本文分享了一位开发者在测试深度学习模型时遇到的输入维度不匹配问题,经过排查图片读取、通道数等环节,最终发现是forward函数中针对batch处理的代码导致的。通过在单个样本输入前添加unsqueeze(0)解决了问题,提醒读者注意模型输入格式对训练和测试的影响。
摘要由CSDN通过智能技术生成

深度学习模型tensor维度对不上是一个非常常见且有时比较难排查的现象。之所以难排查是因为报错信息和真实的错误原因之间的联系往往并不紧密,很难仅仅从PyTorch给出的数字上的信息判断错误在何处。

笔者在一个训练好的模型上测试单个样本时出现了这个问题,一度排查了图片读取、图片通道数等问题,最后发现原因在于forward部分的代码是针对batch编写的,所以输入的格式是四维的BxCxHxW,而测试单个样本时输入是三维的,所以只需要一行img=img.unsqueeze(0)就解决了这个浪费了一上午 + 一下午的问题。。。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值