pytorch中ToTensor()函数失效案例解析

问题描述
  • 在推理模型时,发现模型在训练过程中指标很高,但是当我测试单张图像时,效果很差。
  • 于是,逐步调试,发现了问题所在,在于训练过程中,图像的前处理时,transforms中的ToTensor()函数并没有将图像像素值归一到[0., 1.0]之间。
  • 这就导致了模型是在图像像素值没有归一到[0., 1.0]之间,就送入到了网路中。而在推理时,采用的是numpy作为前处理,正确地将图像像素值归一到了[0., 1.0]之间,从而导致训练和推理两个过程对图像的前处理不一致,从而效果很差。
问题复现
  • ToTensor()函数失效原因是因为在送到ToTensor()函数前,对图像做了格式转换,导致没能归一化到[0, 1]之间。
  • 可复现代码:
    from torchvision import transforms
    import cv2
    
    to_tensor = transforms.ToTensor()
    
    img = cv2.imread('rain_19_16.jpg')
    img = img.astype('float32')
    print(img.shape)
    
    im = to_tensor(img)
    print(im.shape)
    print(im.max())
    
    # 执行结果
    # (450, 600, 3)
    # torch.Size([3, 450, 600])
    # tensor(255.)
    
  • 输出结果为:tensor(255.)。说明经过ToTensor()之后,并没有归一化到[0, 1]之间,只是做了轴的转换。
问题原因
  • 原因就是上述代码的第7行img = img.astype('float32')。类型从uint8转为float32,导致未能成功归一化。
  • ToTensor()源码可以一探究竟,源码位置torchvision/transforms/functional.py的L52行to_tensor()函数
    def to_tensor(pic):
        """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    
        See ``ToTensor`` for more details.
    
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
    
        Returns:
            Tensor: Converted image.
        """
        if not(F_pil._is_pil_image(pic) or _is_numpy(pic)):
            raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
    
        if _is_numpy(pic) and not _is_numpy_image(pic):
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
    
        if isinstance(pic, np.ndarray):
            # handle numpy array
            if pic.ndim == 2:
                pic = pic[:, :, None]
    
            img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
            # backward compatibility
            if isinstance(img, torch.ByteTensor):
                return img.float().div(255)
            else:
                return img
    
  • 如果输入图像格式为uint8,则经过上述代码的23行:img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()之后,img就是torch.ByteTensor的实例;
  • 如果输入图像格式为float32,img就不是torch.ByteTensor的实例。
  • 至于为什么会有不同,torch.from_numpy的这个函数源码就在c++里了,看不到了。
解决方案
  • 送入ToTensor()前,保证图像格式是uint8即可。
  • 8
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值