Pytorch view函数报错解决

本文解析了在PyTorch中view函数与输入tensor不兼容的错误原因,强调了contiguoustensor的重要性,特别是当涉及非连续tensor时。作者通过详细解释了tensor的底层存储逻辑,展示了如何在DataAugmentation和TestTimeAugmentation中遇到这种问题,并讨论了TPU中可能的特殊处理。
摘要由CSDN通过智能技术生成

一切源于做李宏毅作业时的一个报错:

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

首先为什么会有这个错误?

从报错的字面意思来看:view函数的大小跟输入tensor的大小和步长不兼容(至少有一维跨越了两个连续的子空间),使用reshape代替。

错误信息已经很明确了,使用reshape替代view函数就可以解决问题。

但是不由引出几个问题:

为什么view与输入不兼容?跨越两个连续子空间是什么意思?

这其实是同一个问题,view函数不能操作非连续的tensor。要理解非连续tensor首先要了解tensor的底层存储逻辑。

参考4 | PyTorch张量操作:底层存储逻辑-腾讯云开发者社区-腾讯云PyTorch中的contiguous

简单地说,tensor的数据其实是按行优先的方式存储为1维的连续空间的,这跟我们学习的数组是类似的,以下是3*4的tensor a,底层存储在一行。

Tensor是一个上层的封装类,它内部通过offset、stride、size等元数据对底层数据进行索引。例如要访问tensor a[2,3]元素,要通过stride和tensor索引来计算目标索引。例如这里stride=(4, 1),表示沿着行的方向,要跳过4步才能到达下一行,沿着列的方向,要跳过1步到达下一列。因此a[2,3]的目标索引是2*4+3*1=11,也就是如果offset=0,那就是第12个位置。

何为contiguous tensor?

下面的定义不知道最初的出处是哪,不太好理解。但括号里的那句话给了一些提示,对于2维tensor,数据是沿着行的方向存储在内存的。也就是前面说的行优先。其实多维的情况也类似,是沿着维度从左向右顺序存储。

A tensor whose values are laid out in the storage starting from the rightmost dimension onward (that is, moving along rows for a 2D tensor) is defined as contiguous

什么情况下会出现non-contiguous tensor?

tensor少数的操作会导致非连续,参考python - What does .contiguous() do in PyTorch? - Stack Overflow

narrow(),view(),expand()和transpose()

这些操作其实没有改变数据在内存中存储的顺序,它们只是修改了Tensor的元数据,例如前面的3*4的Tensor调用transpose()转置后,会变成4*3的Tensor,但是底层的存储还是同一份,只是把stride改成(1, 4)了。

这时候,Tensor就是non-contiguous的,因为它违反了上面的定义。它不是按行优先存储在内存,而是通过列优先。

至此,我应该讲明白了上面问题出现的原因了。简单说就是,view函数只能用于连续Tensor(在内存中按照Tensor行优先存储的Tensor),一些操作(比如转置),会修改Tensor的元数据,导致Tensor不再遵循行优先存储,从而变成非连续Tensor。

那为什么view不能用于非连续Tensor呢?

文章PyTorch中的contiguous有详细的解释,一句话就是因为性能问题,CPU读取连续的一行数据是非常快速的,但是如果读取不连续的一行数据就要做多次的IO。

那为什么我的代码出现非连续Tensor?

因为使用了Data Augmentation,也就是对训练数据进行了转换,而转换函数往往包含转置等操作。

为什么我训练的时候没有问题,测试的时候才有问题?

因为我训练的时候使用了transforms.Compose函数,该函数的列表最后添加transforms.ToTensor(),在ToTensor的代码中可以看到其调用了contiguous()函数。而在测试的时候,我采用Test Time Augmentation(TTA)用每个transform函数进行转换之后没有调用ToTensor。所以出问题了。

if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

为什么在TPU跑测试没有问题呢?

这个我也搞不懂为什么,可能是TPU中对Tensor有特殊的适配,没有找到相关资料,这个问题就丢下了。

附录:

除了文中引用的资料,还参考以下资料,不得不佩服这些技术大拿清晰的逻辑和表达:

Contigious vs non-contigious tensor - PyTorch Forums

https://medium.com/analytics-vidhya/pytorch-contiguous-vs-non-contiguous-tensor-view-understanding-view-reshape-73e10cdfa0dd

https://medium.com/analytics-vidhya/a-thorough-understanding-of-numpy-strides-and-its-application-in-data-processing-e40eab1c82fe

python - What is the difference between contiguous and non-contiguous arrays? - Stack Overflow

转换怎么把连续的tensor变成不连续的。

Numpy Axes, Explained - Sharp Sight

Pytorch学习笔记——Contiguous vs non-contiguous tensor-CSDN博客

PyTorch – How to check if a tensor is contiguous or not?

python - What functions or modules require contiguous input? - Stack Overflow

torchvision.transforms.functional — Torchvision main documentation

  • 17
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值