问题描述:AODNet在计算ssim时,需要从test_loader中将测试集清晰的图像和去雾后的图像输入给ssim方法进行计算,从test_loader中加载图像对的原始代码如下:
for (img_orig, img_dehaze) in enumerate(test_loader):
ssim_test = ssim(img_orig, img_dehaze)
报错:
Traceback (most recent call last):
File "/home/PycharmProjects/z_fog/AOD-Net/dehaze.py", line 93, in <module>
ssim_test = ssim(img_orig, img_dehaze)
File "/home/PycharmProjects/z_fog/AOD-Net/metrics.py", line 56, in ssim
img1=torch.clamp(img1,min=0,max=1)
TypeError: clamp(): argument 'input' (position 1) must be Tensor, not int
问题分析:
报错内容指出输入ssim的img1不是tensor,而是int,说明从test_loader中读取的图像有误。
打印test_loader的长度和test_list列表的内容,均正确,因此分析test_loader中的数据还是正确的,是加载出来的时候出错。在enumerate循环中打印img_orig和img_dehaze,发现img_dehaze打印出来是代表图像的tensor,而img_orig打印出来时图像的序号,即0~1049.
因此确定是enumerate用法错误,查询用法如下:
来源: Python-enumerate() 函数 - OliYoung - 博客园
enumerate函数有两个输入,第一个参数是序号,第二个是列表或元组等支持迭代的对象。
对于本程序来说,支持迭代的对象是图像对 (img_orig, img_dehaze) ,原程序 for (img_orig, img_dehaze) in enumerate(test_loader):中少了第一个参数,因此把img_orig就当做第一个参数序号了,测试集共1050张图,序号是0~1049,因此打印img_orig出来的就是0~1049.
问题解决:
给enumerate添加第一个参数,改为:
for iter_test, (img_orig, img_dehaze) in enumerate(test_loader):