Pytorch 一文搞懂view,reshape和permute,transpose用法

简介

在pytorch的代码中,经常涉及到tensor形状的变换,而常用的操作就是通过view,reshape,permute这些函数来实现。这几个函数从最后结果来看,都可以改变矩阵的形状,但是对于数据的具体操作其实还是有些许区别。

本文通过具体实例来解释这几者之间的区别。

举个栗子

首先,我们定义一个4个维度【2,2,2,2】的的tensor,并展示它的基本属性。

data = np.arange(16).reshape((2,2,2,2))
tensor = torch.tensor(data)
show_attribute(tensor)

这里展示的属性包括:数据的具体内容,数据排布是否连续,以及每个维度的步长(stride)。

def show_attribute(tensor):
    print(f"data:\n{tensor}")
    print(f"tensor is contiguous: {tensor.is_contiguous()}")
    print(f"tensor stride: {tensor.stride()}")
#     print(f"raw data {tensor.storage()}")

注意,对于连续排布的数据来说,最后一个维度通常步长为1。换句话说,在pytorch中,满足contiguous条件的tensor中最后一个维度的数据,在内存中是连续排列的。

contiguous介绍

在pytorch中,存储的tensor按照行优先(row major)的顺序存储。

如下图,当访问每一行的下一个元素,因为在内存中连续排布,你只需要前进一步(stride)。但是访问下一列的元素,需要前进四步(strides)。(对于列优先的数据,访问下一列的元素,只需要前进一步)

PyTorch中,有些对Tensor的操作并不实际改变tensor的具体数据,只是改变如何根据索引检索到tensor的byte location的方式(元数据),比如 

narrow()view()expand()transpose()permute() [1]。

当对于矩阵进行旋转等操作后,原始数据的排列方式不变,但是访问新矩阵同一行的下一个数据(比如从下图的0 到4),需要的步数将变为4。此时,我们认为矩阵将不再符合contiguous要求。

可以预料的是,当矩阵不满足contiguous时,遍历的效率会变低。并且,其他一些操作会无法进行。后文将举例具体介绍。

permute介绍

permute的作用是调整原始矩阵各个轴(axis)的先后顺序。如下图,对于原始矩阵进行permute操作后,矩阵将不满足contiguous条件,并且,每一个维度对应的stride也改变(常规的矩阵需要满足最后一个维度步长为1)。

而transpose的功能和permute类似,不同的是,transpose只能调整两个轴的相关关系。而permute可以改变多个轴之间的相对关系。

原始矩阵信息
新矩阵信息,最后一个维度步长变为8!

值得注意的是,如果通过storge()函数,展示tensor的原始数据,它的排列方式和原始状态一致,这进一步说明,permute改变了数据每个维度的stride等属性,但是不改变原始数据的排列方式和具体内容。

view介绍

view(x) 会按照想要的矩阵尺寸,得到一个tensor。

比如,我希望把此前的原始tensor转变为想要的形状(16,1)。该操作不会影响contiguous属性。

但是,如果想把此前permute后的矩阵进行view的操作,则失败了。

这一步说明,在进行permute之后需要恢复矩阵的contiguous的属性,否则可能会影响后续操作。

幸好在在torch中,恢复的方式也很简单。如下图,对于此前的tensor通过调用contiguous()恢复了连续性。

如果查看新tensor的原始数据,可以看到,原始数据的排列方式也不一样了。

reshape介绍

view是torch早期就有的一个函数,但是后来,为了优化它在解决上述问题上的困境,reshape函数出现了。reshape函数功能和view一致,但是,当面对不连续(is_contiguous = False)的数据时,它会主动创建一个新对象,去避免报错。如下图。即使对于不连续的对象进行操作,reshape也没有问题。

总结

在pytorch中,为了避免数据拷贝,构造的开销,permute会通过修改元数据来简化操作,但是这会导致其他操作,比如view等的不兼容,因为我们不得不通过调用contiguous来构造新的对象,虽然这样开销更大,但是可以便利后续的自由操作。

理解这一系列操作的内在逻辑,可以帮助我们更好的分析和解决bug。

参考文献

[1] 为什么需要Tensor.contiguous()? PyTorch中Tensor.contiguous()作用分析_tensor contiguous-CSDN博客

[2] python - What is the difference between contiguous and non-contiguous arrays? - Stack Overflow

 [3] python - What's the difference between reshape and view in pytorch? - Stack Overflow

首先,让我们理解这些PyTorch中的函数作用: - `reshape()`: 改变张量的大小,但保持元素总数不变。 - `unsqueeze()`: 在指定维度添加一个尺寸为1的轴。 - `squeeze()`: 删除所有大小为1的维度。 - `transpose()`: 轴之间的元素交换,可以看作是`permute()`的一个特殊情况,通常用于改变矩阵的行和列顺序。 - `permute()`: 完全地重新排列张量的维度。 下面是一个例子,展示了如何使用这些函数将给定形状的图像数据转换成不同的形状: ```python import torch # 假设input_data是 (1, 28, 28, 3) 的torch tensor input_data = torch.randn(1, 28, 28, 3) # 假设它是一个随机生成的数据 # 1. reshape to (28, 28, 3) reshaped_28x28x3 = input_data.permute(1, 2, 0).contiguous().view(28, 28, 3) # 2. unsqueeze to add a new dimension at dim=0 unsqueeze_dim0 = input_data.unsqueeze(0) # (1, 28, 28, 3) -> (3, 28, 28, 3) # 3. squeeze to remove size 1 dimensions squeezed = input_data.squeeze() # (1, 28, 28, 3) -> (28, 28, 3) if the first dimension is 1 # 4. transpose and permute to change the order of dimensions transposed = input_data.transpose(1, 2) # (1, 28, 28, 3) -> (1, 28, 3, 28) permuted = input_data.permute(2, 0, 1) # (1, 28, 28, 3) -> (3, 1, 28, 28) print(f"Original shape: {input_data.shape}") print(f"reshaped_28x28x3: {reshaped_28x28x3.shape}") print(f"unsqueeze_dim0: {unsqueeze_dim0.shape}") print(f"squeezed: {squeezed.shape} (if squeezed first dim was 1)") print(f"transposed: {transposed.shape}") print(f"permuted: {permuted.shape}") ``` 请注意,如果输入数据的某些维度原本就是1,直接`squeeze()`可能会丢失信息,除非在之前先确认不需要保留那些维度。同时,`unsqueeze()`和`permute()`/`transpose()`的结果取决于原始数据的具体结构。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值