使用torch.functions.to_tensor()把数组转换为Tensor时需要注意维度是否正确

深度学习多波段数据训练时,模型一直无法收敛。

torch.functions.to_tensor()函数

下面是to_tensor函数的源码

def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.ToTensor` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not(F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

    default_float_dtype = torch.get_default_dtype()


    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=default_float_dtype)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.to(dtype=default_float_dtype).div(255)
    else:
        return img

当我们输入的是数组时,而不是PIL的Image对象时,我们要看下面这一段代码:

注意这里有一个维度变换的操作,假如你输入的照片是(C,H,W),那么这里就会变成(W,C,H)

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]
		# 注意这里有一个维度变换的操作,假如你输入的照片是(C,H,W),那么这里就会变成(W,C,H)
        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

我先说一下我遇到这个问题的始末,我在做多通道图像训练模型时,模型的数据读取部分是用PIL模块进行的,照片的格式是.tif。但当图像通道数大于4个的时候,PIL模块就无法对tif图像进行读取了,因此我用gdal模块重新定义了数据的读取部分,自定义的读取图像函数是把图像转换为Numpy.ndarray类型,因为functions.to.tensor()函数只支持PIL.Image对象和Numpy.ndarray对象。但functions.to.tensor()对于数组输入时会如我上面所说的,把ndarray进行维度变换了,因此就碰到了模型不收敛的问题。我输入的图像是(5,600,600),在to.tensor()函数输出就变成了(600,3,600),由于我的高和宽都是600,因此我第一时间没法分辨第一个600到底是高还是宽。我把输出的tensor用transpose(1,0)函数交换了维度,变成了(3,600,600),但实际上现在是(C,W,H)。所以模型在验证时一直计算错误,导致无法收敛。后来我是用tersor.permute(1,2,0)把维度变回(C,H,W),模型就正常了。

下面我用图像展示一下(C,H,W)和(C,W,H)的对比,方便大家理解为什么出错。
在这里插入图片描述
假如红色的是GT(正确的标签),可以看到样本已经被扭曲了,所以验证集无法准确验证,模型就无法收敛。

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值