Pytorch 一文搞懂view,reshape和permute,transpose用法

简介

在pytorch的代码中,经常涉及到tensor形状的变换,而常用的操作就是通过view,reshape,permute这些函数来实现。这几个函数从最后结果来看,都可以改变矩阵的形状,但是对于数据的具体操作其实还是有些许区别。

本文通过具体实例来解释这几者之间的区别。

举个栗子

首先,我们定义一个4个维度【2,2,2,2】的的tensor,并展示它的基本属性。

data = np.arange(16).reshape((2,2,2,2))
tensor = torch.tensor(data)
show_attribute(tensor)

这里展示的属性包括:数据的具体内容,数据排布是否连续,以及每个维度的步长(stride)。

def show_attribute(tensor):
    print(f"data:\n{tensor}")
    print(f"tensor is contiguous: {tensor.is_contiguous()}")
    print(f"tensor stride: {tensor.stride()}")
#     print(f"raw data {tensor.storage()}")

注意,对于连续排布的数据来说,最后一个维度通常步长为1。换句话说,在pytorch中,满足contiguous条件的tensor中最后一个维度的数据,在内存中是连续排列的。

contiguous介绍

在pytorch中,存储的tensor按照行优先(row major)的顺序存储。

如下图,当访问每一行的下一个元素,因为在内存中连续排布,你只需要前进一步(stride)。但是访问下一列的元素,需要前进四步(strides)。(对于列优先的数据,访问下一列的元素,只需要前进一步)

PyTorch中,有些对Tensor的操作并不实际改变tensor的具体数据,只是改变如何根据索引检索到tensor的byte location的方式(元数据),比如 

narrow()view()expand()transpose()permute() [1]。

当对于矩阵进行旋转等操作后,原始数据的排列方式不变,但是访问新矩阵同一行的下一个数据(比如从下图的0 到4),需要的步数将变为4。此时,我们认为矩阵将不再符合contiguous要求。

可以预料的是,当矩阵不满足contiguous时,遍历的效率会变低。并且,其他一些操作会无法进行。后文将举例具体介绍。

permute介绍

permute的作用是调整原始矩阵各个轴(axis)的先后顺序。如下图,对于原始矩阵进行permute操作后,矩阵将不满足contiguous条件,并且,每一个维度对应的stride也改变(常规的矩阵需要满足最后一个维度步长为1)。

而transpose的功能和permute类似,不同的是,transpose只能调整两个轴的相关关系。而permute可以改变多个轴之间的相对关系。

原始矩阵信息
新矩阵信息,最后一个维度步长变为8!

值得注意的是,如果通过storge()函数,展示tensor的原始数据,它的排列方式和原始状态一致,这进一步说明,permute改变了数据每个维度的stride等属性,但是不改变原始数据的排列方式和具体内容。

view介绍

view(x) 会按照想要的矩阵尺寸,得到一个tensor。

比如,我希望把此前的原始tensor转变为想要的形状(16,1)。该操作不会影响contiguous属性。

但是,如果想把此前permute后的矩阵进行view的操作,则失败了。

这一步说明,在进行permute之后需要恢复矩阵的contiguous的属性,否则可能会影响后续操作。

幸好在在torch中,恢复的方式也很简单。如下图,对于此前的tensor通过调用contiguous()恢复了连续性。

如果查看新tensor的原始数据,可以看到,原始数据的排列方式也不一样了。

reshape介绍

view是torch早期就有的一个函数,但是后来,为了优化它在解决上述问题上的困境,reshape函数出现了。reshape函数功能和view一致,但是,当面对不连续(is_contiguous = False)的数据时,它会主动创建一个新对象,去避免报错。如下图。即使对于不连续的对象进行操作,reshape也没有问题。

总结

在pytorch中,为了避免数据拷贝,构造的开销,permute会通过修改元数据来简化操作,但是这会导致其他操作,比如view等的不兼容,因为我们不得不通过调用contiguous来构造新的对象,虽然这样开销更大,但是可以便利后续的自由操作。

理解这一系列操作的内在逻辑,可以帮助我们更好的分析和解决bug。

参考文献

[1] 为什么需要Tensor.contiguous()? PyTorch中Tensor.contiguous()作用分析_tensor contiguous-CSDN博客

[2] python - What is the difference between contiguous and non-contiguous arrays? - Stack Overflow

 [3] python - What's the difference between reshape and view in pytorch? - Stack Overflow

  • 28
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
你想了解关于 PyTorch 中的 ShuffleAttention 吗?ShuffleAttention 是一种注意力机制,用于在自注意力模型中对输入序列进行特征提取。它在 Transformer 模型的基础上进行了改进,通过引入 Shuffle 操作来增强模型的表达能力。 在 PyTorch 中,你可以通过自定义一个 ShuffleAttention 类来实现 ShuffleAttention。下面是一个简单的示例代码: ```python import torch import torch.nn as nn class ShuffleAttention(nn.Module): def __init__(self, dim, num_heads=8, dropout=0.1): super(ShuffleAttention, self).__init__() self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3) self.attn_dropout = nn.Dropout(dropout) self.proj = nn.Linear(dim, dim) self.proj_dropout = nn.Dropout(dropout) def forward(self, x): B, L, C = x.shape H = self.num_heads head_dim = self.head_dim qkv = self.qkv(x).reshape(B, L, 3, H, head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn_weights = (q @ k.transpose(-2, -1)) * self.scale attn_probs = nn.Softmax(dim=-1)(attn_weights) attn_probs = self.attn_dropout(attn_probs) attended_vals = attn_probs @ v attended_vals = attended_vals.transpose(1, 2).reshape(B, L, C) x = self.proj_dropout(self.proj(attended_vals)) return x ``` 这是一个简化版的 ShuffleAttention 实现,其中包含了自注意力机制的关键步骤,如计算注意力权重、进行注意力加权和投影操作等。你可以根据自己的需求进行修改和扩展。 希望以上信息能对你有所帮助!如果还有其他问题,请随时提问。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值