简介
在pytorch的代码中,经常涉及到tensor形状的变换,而常用的操作就是通过view,reshape,permute这些函数来实现。这几个函数从最后结果来看,都可以改变矩阵的形状,但是对于数据的具体操作其实还是有些许区别。
本文通过具体实例来解释这几者之间的区别。
举个栗子
首先,我们定义一个4个维度【2,2,2,2】的的tensor,并展示它的基本属性。
data = np.arange(16).reshape((2,2,2,2))
tensor = torch.tensor(data)
show_attribute(tensor)
这里展示的属性包括:数据的具体内容,数据排布是否连续,以及每个维度的步长(stride)。
def show_attribute(tensor):
print(f"data:\n{tensor}")
print(f"tensor is contiguous: {tensor.is_contiguous()}")
print(f"tensor stride: {tensor.stride()}")
# print(f"raw data {tensor.storage()}")
注意,对于连续排布的数据来说,最后一个维度通常步长为1。换句话说,在pytorch中,满足contiguous条件的tensor中最后一个维度的数据,在内存中是连续排列的。
contiguous介绍
在pytorch中,存储的tensor按照行优先(row major)的顺序存储。
如下图,当访问每一行的下一个元素,因为在内存中连续排布,你只需要前进一步(stride)。但是访问下一列的元素,需要前进四步(strides)。(对于列优先的数据,访问下一列的元素,只需要前进一步)
在PyTorch中,有些对Tensor的操作并不实际改变tensor的具体数据,只是改变如何根据索引检索到tensor的byte location的方式(元数据),比如
narrow()
, view()
, expand()
, transpose()
,permute()
[1]。
当对于矩阵进行旋转等操作后,原始数据的排列方式不变,但是访问新矩阵同一行的下一个数据(比如从下图的0 到4),需要的步数将变为4。此时,我们认为矩阵将不再符合contiguous要求。
可以预料的是,当矩阵不满足contiguous时,遍历的效率会变低。并且,其他一些操作会无法进行。后文将举例具体介绍。
permute介绍
permute的作用是调整原始矩阵各个轴(axis)的先后顺序。如下图,对于原始矩阵进行permute操作后,矩阵将不满足contiguous条件,并且,每一个维度对应的stride也改变(常规的矩阵需要满足最后一个维度步长为1)。
而transpose的功能和permute类似,不同的是,transpose只能调整两个轴的相关关系。而permute可以改变多个轴之间的相对关系。
值得注意的是,如果通过storge()函数,展示tensor的原始数据,它的排列方式和原始状态一致,这进一步说明,permute改变了数据每个维度的stride等属性,但是不改变原始数据的排列方式和具体内容。
view介绍
view(x) 会按照想要的矩阵尺寸,得到一个tensor。
比如,我希望把此前的原始tensor转变为想要的形状(16,1)。该操作不会影响contiguous属性。
但是,如果想把此前permute后的矩阵进行view的操作,则失败了。
这一步说明,在进行permute之后需要恢复矩阵的contiguous的属性,否则可能会影响后续操作。
幸好在在torch中,恢复的方式也很简单。如下图,对于此前的tensor通过调用contiguous()恢复了连续性。
如果查看新tensor的原始数据,可以看到,原始数据的排列方式也不一样了。
reshape介绍
view是torch早期就有的一个函数,但是后来,为了优化它在解决上述问题上的困境,reshape函数出现了。reshape函数功能和view一致,但是,当面对不连续(is_contiguous = False)的数据时,它会主动创建一个新对象,去避免报错。如下图。即使对于不连续的对象进行操作,reshape也没有问题。
总结
在pytorch中,为了避免数据拷贝,构造的开销,permute会通过修改元数据来简化操作,但是这会导致其他操作,比如view等的不兼容,因为我们不得不通过调用contiguous来构造新的对象,虽然这样开销更大,但是可以便利后续的自由操作。
理解这一系列操作的内在逻辑,可以帮助我们更好的分析和解决bug。
参考文献
[1] 为什么需要Tensor.contiguous()? PyTorch中Tensor.contiguous()作用分析_tensor contiguous-CSDN博客
[2] python - What is the difference between contiguous and non-contiguous arrays? - Stack Overflow
[3] python - What's the difference between reshape and view in pytorch? - Stack Overflow