一切源于做李宏毅作业时的一个报错:
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
首先为什么会有这个错误?
从报错的字面意思来看:view函数的大小跟输入tensor的大小和步长不兼容(至少有一维跨越了两个连续的子空间),使用reshape代替。
错误信息已经很明确了,使用reshape替代view函数就可以解决问题。
但是不由引出几个问题:
为什么view与输入不兼容?跨越两个连续子空间是什么意思?
这其实是同一个问题,view函数不能操作非连续的tensor。要理解非连续tensor首先要了解tensor的底层存储逻辑。
参考4 | PyTorch张量操作:底层存储逻辑-腾讯云开发者社区-腾讯云和PyTorch中的contiguous
简单地说,tensor的数据其实是按行优先的方式存储为1维的连续空间的,这跟我们学习的数组是类似的,以下是3*4的tensor a,底层存储在一行。
Tensor是一个上层的封装类,它内部通过offset、stride、size等元数据对底层数据进行索引。例如要访问tensor a[2,3]元素,要通过stride和tensor索引来计算目标索引。例如这里stride=(4, 1),表示沿着行的方向,要跳过4步才能到达下一行,沿着列的方向,要跳过1步到达下一列。因此a[2,3]的目标索引是2*4+3*1=11,也就是如果offset=0,那就是第12个位置。
何为contiguous tensor?
下面的定义不知道最初的出处是哪,不太好理解。但括号里的那句话给了一些提示,对于2维tensor,数据是沿着行的方向存储在内存的。也就是前面说的行优先。其实多维的情况也类似,是沿着维度从左向右顺序存储。
A tensor whose values are laid out in the storage starting from the rightmost dimension onward (that is, moving along rows for a 2D tensor) is defined as contiguous
什么情况下会出现non-contiguous tensor?
tensor少数的操作会导致非连续,参考python - What does .contiguous() do in PyTorch? - Stack Overflow
narrow(),view(),expand()和transpose()
这些操作其实没有改变数据在内存中存储的顺序,它们只是修改了Tensor的元数据,例如前面的3*4的Tensor调用transpose()转置后,会变成4*3的Tensor,但是底层的存储还是同一份,只是把stride改成(1, 4)了。
这时候,Tensor就是non-contiguous的,因为它违反了上面的定义。它不是按行优先存储在内存,而是通过列优先。
至此,我应该讲明白了上面问题出现的原因了。简单说就是,view函数只能用于连续Tensor(在内存中按照Tensor行优先存储的Tensor),一些操作(比如转置),会修改Tensor的元数据,导致Tensor不再遵循行优先存储,从而变成非连续Tensor。
那为什么view不能用于非连续Tensor呢?
文章PyTorch中的contiguous有详细的解释,一句话就是因为性能问题,CPU读取连续的一行数据是非常快速的,但是如果读取不连续的一行数据就要做多次的IO。
那为什么我的代码出现非连续Tensor?
因为使用了Data Augmentation,也就是对训练数据进行了转换,而转换函数往往包含转置等操作。
为什么我训练的时候没有问题,测试的时候才有问题?
因为我训练的时候使用了transforms.Compose函数,该函数的列表最后添加transforms.ToTensor(),在ToTensor的代码中可以看到其调用了contiguous()函数。而在测试的时候,我采用Test Time Augmentation(TTA)用每个transform函数进行转换之后没有调用ToTensor。所以出问题了。
if isinstance(pic, np.ndarray):
# handle numpy array
if pic.ndim == 2:
pic = pic[:, :, None]
img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
# backward compatibility
if isinstance(img, torch.ByteTensor):
return img.to(dtype=default_float_dtype).div(255)
else:
return img
为什么在TPU跑测试没有问题呢?
这个我也搞不懂为什么,可能是TPU中对Tensor有特殊的适配,没有找到相关资料,这个问题就丢下了。
附录:
除了文中引用的资料,还参考以下资料,不得不佩服这些技术大拿清晰的逻辑和表达:
Contigious vs non-contigious tensor - PyTorch Forums
python - What is the difference between contiguous and non-contiguous arrays? - Stack Overflow
转换怎么把连续的tensor变成不连续的。
Numpy Axes, Explained - Sharp Sight
Pytorch学习笔记——Contiguous vs non-contiguous tensor-CSDN博客
PyTorch – How to check if a tensor is contiguous or not?
python - What functions or modules require contiguous input? - Stack Overflow
torchvision.transforms.functional — Torchvision main documentation