对张量shape操作函数汇总,附代码(reshape(), torch.cat(), view(), torch.stack(), transpose(), permute(), unsqueeze)

'''调整torch.tensor张量的shape的函数汇总'''
import torch
import numpy as np
from einops import rearrange, reduce, repeat

x = [
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12]
    ],
    [
        [13, 14, 15, 16],
        [17, 18, 19, 20],
        [21, 22, 23, 24]
    ]
]
# 将x转化为张量,数据类型为浮点数
x = torch.tensor(x).float()
print(x.shape)  # torch.Size([2, 3, 4])

# einops中的rearrange()函数,支持多个维度上的shape变换,组成张量的元素保持不变
x1_1 = rearrange(x, 'h w c -> c h w')  # 调换前两个shape
print(x1_1)
print(x1_1.shape)
"""
tensor([[[ 1.,  5.,  9.],
         [13., 17., 21.]],
        [[ 2.,  6., 10.],
         [14., 18., 22.]],
        [[ 3.,  7., 11.],
         [15., 19., 23.]],
        [[ 4.,  8., 12.],
         [16., 20., 24.]]])
torch.Size([4, 2, 3])
"""

x1_2 = rearrange(x, 'h (w1 w) c -> (w1 h) w c', w1=3)  # 将第0维度的shape×3,第1维度的shape除以3
print(x1_2)
print(x1_2.shape)
"""
tensor([[[ 1.,  2.,  3.,  4.]],
        [[13., 14., 15., 16.]],
        [[ 5.,  6.,  7.,  8.]],
        [[17., 18., 19., 20.]],
        [[ 9., 10., 11., 12.]],
        [[21., 22., 23., 24.]]])
torch.Size([6, 1, 4])
"""

x1_3 = rearrange(x, 'h w c -> 1 h w c')  # 张量扩维(维数),三维到四维,组成张量的元素不变
print(x1_3)
print(x1_3.shape)
"""
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.]],
         [[13., 14., 15., 16.],
          [17., 18., 19., 20.],
          [21., 22., 23., 24.]]]])
torch.Size([1, 2, 3, 4])
"""

x1_4 = rearrange(x, 'h w c -> (h w c) 1')  # 张量降维(维数),组成张量的元素不变
print(x1_4)
print(x1_4.shape)
"""
tensor([[ 1.],
        [ 2.],
        [ 3.],
        [ 4.],
        [ 5.],
        [ 6.],
        [ 7.],
        [ 8.],
        [ 9.],
        [10.],
        [11.],
        [12.],
        [13.],
        [14.],
        [15.],
        [16.],
        [17.],
        [18.],
        [19.],
        [20.],
        [21.],
        [22.],
        [23.],
        [24.]])
torch.Size([24, 1])
"""

# repeat()函数通过在shape的某个维度(dim)上复制组成张量的元素实现对shape的调整
x2_1 = repeat(x, 'h w c -> h w c a', a=2)  # 张量扩维(维数)
print(x2_1)
print(x2_1.shape)
"""
tensor([[[[ 1.,  1.],
          [ 2.,  2.],
          [ 3.,  3.],
          [ 4.,  4.]],
         [[ 5.,  5.],
          [ 6.,  6.],
          [ 7.,  7.],
          [ 8.,  8.]],
         [[ 9.,  9.],
          [10., 10.],
          [11., 11.],
          [12., 12.]]],
        [[[13., 13.],
          [14., 14.],
          [15., 15.],
          [16., 16.]],
         [[17., 17.],
          [18., 18.],
          [19., 19.],
          [20., 20.]],
         [[21., 21.],
          [22., 22.],
          [23., 23.],
          [24., 24.]]]])
torch.Size([2, 3, 4, 2])
"""

x2_2 = repeat(x, 'h w c -> (h1 h) w (c1 c)', h1=2, c1=2)  # 对shape的任意维度进行扩增,通过元素复制
print(x2_2)
print(x2_2.shape)
"""
tensor([[[ 1.,  2.,  3.,  4.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.,  9., 10., 11., 12.]],
        [[13., 14., 15., 16., 13., 14., 15., 16.],
         [17., 18., 19., 20., 17., 18., 19., 20.],
         [21., 22., 23., 24., 21., 22., 23., 24.]],
        [[ 1.,  2.,  3.,  4.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.,  9., 10., 11., 12.]],
        [[13., 14., 15., 16., 13., 14., 15., 16.],
         [17., 18., 19., 20., 17., 18., 19., 20.],
         [21., 22., 23., 24., 21., 22., 23., 24.]]])
torch.Size([4, 3, 8])
"""

# reduce()函数可以对张量的shape降维(维数),也可以在特定维度上减小,必须指定池化方式('max', 'min', 'mean')
x3_1 = reduce(x, 'b h w -> h w', 'max')  # 'max' 代表采用最大池化降维,不同的池化类型导致最终不同的元素组成
print(x3_1)
print(x3_1.shape)
"""
tensor([[13., 14., 15., 16.],
        [17., 18., 19., 20.],
        [21., 22., 23., 24.]])
torch.Size([3, 4])
"""

x3_2 = reduce(x, 'b h w -> h w', 'min')  # 'min' 代表最小池化,即选择对应位置的最小元素进行池化
print(x3_2)
print(x3_2.shape)
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])
torch.Size([3, 4])
"""

x3_3 = reduce(x, 'b h w -> h w', 'mean')  # 'mean'代表平均池化,代表使用每个位置所有元素的平均进行池化
print(x3_3)
print(x3_3.shape)
"""
tensor([[ 7.,  8.,  9., 10.],
        [11., 12., 13., 14.],
        [15., 16., 17., 18.]])
torch.Size([3, 4])
"""

# reduce()也可调整shape特定维度上值
x3_4 = reduce(x, '(b1 b) h (w1 w) -> b h w', 'max', b1=2, w1=2)  # 最大池化
print(x3_4)
print(x3_4.shape)
"""
tensor([[[15., 16.],
         [19., 20.],
         [23., 24.]]])
torch.Size([1, 3, 2])
"""

x3_5 = reduce(x, '(b1 b) h (w1 w) -> b h w', 'min', b1=2, w1=2)  # 最小池化
print(x3_5)
print(x3_5.shape)
"""
tensor([[[ 1.,  2.],
         [ 5.,  6.],
         [ 9., 10.]]])
torch.Size([1, 3, 2])
"""

x3_6 = reduce(x, '(b1 b) h (w1 w) -> b h w', 'mean', b1=2, w1=2)  # 平均池化
print(x3_6)
print(x3_6.shape)
"""
tensor([[[ 8.,  9.],
         [12., 13.],
         [16., 17.]]])
torch.Size([1, 3, 2])
"""

print(x)
print(x.shape)
"""
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]],
        [[13., 14., 15., 16.],
         [17., 18., 19., 20.],
         [21., 22., 23., 24.]]])
torch.Size([2, 3, 4])
"""
# x.view(new_shape)函数将张量x.shape修改为new_shape,但必须保持组成的元素不变,所在new_shape是有一定限制的
x4_1 = x.view(6, 4)
print(x4_1)
print(x4_1.shape)
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.],
        [17., 18., 19., 20.],
        [21., 22., 23., 24.]])
torch.Size([6, 4])
"""

# x.view(new_shape)函数在new_shape中使用-1代表shape在该维度的值自动计算,需要满足组成元素不变
x4_2 = x.view(-1, 2, 3)
print(x4_2)
print(x4_2.shape)
"""
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],
        [[ 7.,  8.,  9.],
         [10., 11., 12.]],
        [[13., 14., 15.],
         [16., 17., 18.]],
        [[19., 20., 21.],
         [22., 23., 24.]]])
torch.Size([4, 2, 3])
"""

x4_3 = x.view(2, -1, 2)
print(x4_3)
print(x4_3.shape)
"""
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.],
         [ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],
        [[13., 14.],
         [15., 16.],
         [17., 18.],
         [19., 20.],
         [21., 22.],
         [23., 24.]]])
torch.Size([2, 6, 2])
"""
# torch.cat()函数,对两个张量在特定维度进行拼接,要求两个张量的shape在除拼接维度上均相等,否则报错
y = torch.tensor([[[ 1.],
                   [ 1.],
                   [ 1.]],
                  [[2.],
                   [2.],
                   [2.]]])
print(y.shape)  # torch.Size([2, 3, 1])

x5_1 = torch.cat((x, y), dim=2)  # (x, y)是拼接的两个张量,有先后顺序,dim=2是拼接维度的索引
print(x5_1)
print(x5_1.shape)
"""
tensor([[[ 1.,  2.,  3.,  4.,  1.],
         [ 5.,  6.,  7.,  8.,  1.],
         [ 9., 10., 11., 12.,  1.]],
        [[13., 14., 15., 16.,  2.],
         [17., 18., 19., 20.,  2.],
         [21., 22., 23., 24.,  2.]]])
torch.Size([2, 3, 5])
"""

# x.reshape(new_shape)函数改变张量x.shape为指定的new_shape,需满足组成张量的元素保持不变
x6_1 = x.reshape(1, 6, 4)
print(x6_1)
print(x6_1.shape)
"""
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.],
         [13., 14., 15., 16.],
         [17., 18., 19., 20.],
         [21., 22., 23., 24.]]])
torch.Size([1, 6, 4])
"""
# unsqueeze/squeeze() 在指定维度上添加或移除一个值为1的维度,从而达到升维/降维(维数)
x7_1 = x.unsqueeze(0)  # 在索引为0的维度上添加一个值为1的维度
print(x7_1)
print(x7_1.shape)
"""
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.]],
         [[13., 14., 15., 16.],
          [17., 18., 19., 20.],
          [21., 22., 23., 24.]]]])
torch.Size([1, 2, 3, 4])
"""

x7_2 = x7_1.squeeze(0)  # 去除x.shape的dim=0的维度(若shape在该维度上为1),否则保持不变;若不指定索引,如x.squeeze()则在所有维度上均类此操作
print(x7_2)
print(x7_2.shape)
"""
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]],
        [[13., 14., 15., 16.],
         [17., 18., 19., 20.],
         [21., 22., 23., 24.]]])
torch.Size([2, 3, 4])
"""

# torch.stack((a, b), dim)函数实现升维(维数+1),并改变shape在指定维度(dim)的值
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]])

print(a.shape)  # torch.Size([3, 3])
print(b.shape)  # torch.Size([3, 3])

# 沿dim=0进行拼接:因为使用torch.stack()会总维数+1变成三维张量,沿dim=0可理解为在新三维张量shape的第0个维度上进行连接,即在最外面增加一层嵌套列表!
c_1 = torch.stack((a, b), dim=0)
print(c_1)
print(c_1.shape)
"""
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],
        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
torch.Size([2, 3, 3])
"""

# 沿dim=1拼接,对张量a、b而言,dim=1即代表行,即在行方向上进行张量a与b的拼接.所以原本对a而言,第一行只有[1, 2, 3]现在要拼接上[10, 20, 30],其他两行也是如此
c_2 = torch.stack((a, b), dim=1)
print(c_2)
print(c_2.shape)
"""
tensor([[[ 1,  2,  3],
         [10, 20, 30]],
        [[ 4,  5,  6],
         [40, 50, 60]],
        [[ 7,  8,  9],
         [70, 80, 90]]])
torch.Size([3, 2, 3])
"""

# 沿dim=2进行拼接,对张量a、b而言,dim=1对应列,即在列方向上进行拼接
# 不失一般性,对a的第一列,原本是[1, 4, 7]的转置,现在要拼接上[10, 40, 70]的转置,其他两列也是这样。
c_3 = torch.stack((a, b), dim=2)
print(c_3)
print(c_3.shape)
"""
tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],
        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]],
        [[ 7, 70],
         [ 8, 80],
         [ 9, 90]]])
torch.Size([3, 3, 2])
"""

# x.transpose(dim1, dim2)函数交换shape在dim1和dim2两个维度的值(这里的dim类似于维度的索引)
x8_1 = x.transpose(2, 0)  # 这里交换张量x的dim=0维度和dim=2维度
print(x8_1)
print(x8_1.shape)
"""
tensor([[[ 1., 13.],
         [ 5., 17.],
         [ 9., 21.]],
        [[ 2., 14.],
         [ 6., 18.],
         [10., 22.]],
        [[ 3., 15.],
         [ 7., 19.],
         [11., 23.]],
        [[ 4., 16.],
         [ 8., 20.],
         [12., 24.]]])
torch.Size([4, 3, 2])
"""

# x.permute()函数与x.transpose(dim1, dim2)类似,允许一次性交换x.shape的多个维度
x8_2 = x.permute(1, 2, 0)
print(x8_2)
print(x8_2.shape)
"""
tensor([[[ 1., 13.],
         [ 2., 14.],
         [ 3., 15.],
         [ 4., 16.]],
        [[ 5., 17.],
         [ 6., 18.],
         [ 7., 19.],
         [ 8., 20.]],
        [[ 9., 21.],
         [10., 22.],
         [11., 23.],
         [12., 24.]]])
torch.Size([3, 4, 2])
"""

# x.unbind(dim)函数在指定维度(dim)上进行拆分,移除shape在该维度的值,并返回拆分后的(多个)张量组成的元组.拆分张量的数量==shape在拆分维度(dim)的值
x9_1 = x.unbind(0)  # 沿dim=0进行拆分,并去除shape在该维度的值,返回拆分后的多个张量组成的元组(turple())
print(x9_1)
print(x9_1[0].shape)  # x9_1[0].shape == x9_1[1].shape
"""
(tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]]), 
 tensor([[13., 14., 15., 16.],
        [17., 18., 19., 20.],
        [21., 22., 23., 24.]]))
torch.Size([3, 4])
"""

x9_2 = x.unbind(dim=1)  # 沿dim=1进行拆分,返回拆分后的多个张量组成的元组(turple())
print(x9_2)
print(x9_2[0].shape)
"""
(tensor([[ 1.,  2.,  3.,  4.],
        [13., 14., 15., 16.]]), 
 tensor([[ 5.,  6.,  7.,  8.],
        [17., 18., 19., 20.]]), 
 tensor([[ 9., 10., 11., 12.],
        [21., 22., 23., 24.]]))
torch.Size([2, 4])
"""

x9_3 = x.unbind(2)  # 沿dim=2进行拆分,返回拆分后的多个张量组成的元组(turple())
print(x9_3)
print(x9_3[3].shape)
"""
(tensor([[ 1.,  5.,  9.],
        [13., 17., 21.]]), 
 tensor([[ 2.,  6., 10.],
        [14., 18., 22.]]), 
 tensor([[ 3.,  7., 11.],
        [15., 19., 23.]]), 
 tensor([[ 4.,  8., 12.],
        [16., 20., 24.]]))
torch.Size([2, 3])
"""

  • 15
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值