pytorch中reshape()、view()、permute()、transpose()总结

1. reshape() 和 view()

参考链接:
PyTorch中view的用法
pytorch中contiguous()

功能相似,但是 view() 只能操作 tensor,reshape() 可以操作 tensor 和 ndarray。view() 只能用在 contiguous 的 variable 上。如果在 view 之前用了 transpose, permute 等,需要用 contiguous() 来返回一个 contiguous copy。 view() 操作后的 tensor 和原 tensor 共享存储。
pytorch 中的 torch.reshape() 大致相当于 tensor.contiguous().view()。

import torch
import numpy as np

a = np.arange(6)
print('a:\n',a)
b = a.reshape(2,3)
print('b-np.reahspe():\n', b)
c = torch.tensor(a)
d = c.reshape(2,3)
print('d-torch.reshape():\n',d)
e = c.view(2,3)
print('e-torch.view()\n',e)

输出:

a:
 [0 1 2 3 4 5]
b-np.reahspe():
 [[0 1 2]
 [3 4 5]]
d-torch.reshape():
 tensor([[0, 1, 2],
        [3, 4, 5]], dtype=torch.int32)
e-torch.view()
 tensor([[0, 1, 2],
        [3, 4, 5]], dtype=torch.int32)
2. permute() 和 transpose()

参考链接:
pytorch — tensor.permute()和torch.transpose()

两者都是实现维度之间的交换,transpose 只能一次转换两个维度,permute 可以一次转换多个维度,permute 可以通过 transpose 组合的形式实现。在卷积神经网络中,cross-channel max pooling 就用到了这类变换。

permute:

import torch
a = torch.tensor([[[0,1,2],[3,4,5]]])
print("a.shape:", a.shape)
print('a:\n',a)
b = a.permute(0,2,1)
print("b.shape:", b.shape)
print('b:\n',b)
c = a.permute(1,2,0)
print("c.shape:", c.shape)
print('c\n',c)

输出:

a.shape: torch.Size([1, 2, 3])
a:
 tensor([[[0, 1, 2],
         [3, 4, 5]]])
b.shape: torch.Size([1, 3, 2])
b:
 tensor([[[0, 3],
         [1, 4],
         [2, 5]]])
c.shape: torch.Size([2, 3, 1])
c
 tensor([[[0],
         [1],
         [2]],

        [[3],
         [4],
         [5]]])

transpose

a = torch.tensor([[[0,1,2],[3,4,5]]])
print("a.shape:", a.shape)
print('a:\n',a)
d = a.transpose(0,1)
print("d.shape:", d.shape)
print('d\n',d)
e = a.transpose(2,1)
print("e.shape:", e.shape)
print('e\n',e)

输出:

a.shape: torch.Size([1, 2, 3])
a:
 tensor([[[0, 1, 2],
         [3, 4, 5]]])
d.shape: torch.Size([2, 1, 3])
d
 tensor([[[0, 1, 2]],

        [[3, 4, 5]]])
e.shape: torch.Size([1, 3, 2])
e
 tensor([[[0, 3],
         [1, 4],
         [2, 5]]])

reshape 与 view 可以重新设置维度;permute 和 transpose 只能 在已有的维度之间转换,并且包含转置的概念。

  • 15
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
你想了解关于 PyTorch 的 ShuffleAttention 吗?ShuffleAttention 是一种注意力机制,用于在自注意力模型对输入序列进行特征提取。它在 Transformer 模型的基础上进行了改进,通过引入 Shuffle 操作来增强模型的表达能力。 在 PyTorch ,你可以通过自定义一个 ShuffleAttention 类来实现 ShuffleAttention。下面是一个简单的示例代码: ```python import torch import torch.nn as nn class ShuffleAttention(nn.Module): def __init__(self, dim, num_heads=8, dropout=0.1): super(ShuffleAttention, self).__init__() self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3) self.attn_dropout = nn.Dropout(dropout) self.proj = nn.Linear(dim, dim) self.proj_dropout = nn.Dropout(dropout) def forward(self, x): B, L, C = x.shape H = self.num_heads head_dim = self.head_dim qkv = self.qkv(x).reshape(B, L, 3, H, head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn_weights = (q @ k.transpose(-2, -1)) * self.scale attn_probs = nn.Softmax(dim=-1)(attn_weights) attn_probs = self.attn_dropout(attn_probs) attended_vals = attn_probs @ v attended_vals = attended_vals.transpose(1, 2).reshape(B, L, C) x = self.proj_dropout(self.proj(attended_vals)) return x ``` 这是一个简化版的 ShuffleAttention 实现,其包含了自注意力机制的关键步骤,如计算注意力权重、进行注意力加权和投影操作等。你可以根据自己的需求进行修改和扩展。 希望以上信息能对你有所帮助!如果还有其他问题,请随时提问。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值