'''调整torch.tensor张量的shape的函数汇总'''
import torch
import numpy as np
from einops import rearrange, reduce, repeat
x = [
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
],
[
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]
]
]
# 将x转化为张量,数据类型为浮点数
x = torch.tensor(x).float()
print(x.shape) # torch.Size([2, 3, 4])
# einops中的rearrange()函数,支持多个维度上的shape变换,组成张量的元素保持不变
x1_1 = rearrange(x, 'h w c -> c h w') # 调换前两个shape
print(x1_1)
print(x1_1.shape)
"""
tensor([[[ 1., 5., 9.],
[13., 17., 21.]],
[[ 2., 6., 10.],
[14., 18., 22.]],
[[ 3., 7., 11.],
[15., 19., 23.]],
[[ 4., 8., 12.],
[16., 20., 24.]]])
torch.Size([4, 2, 3])
"""
x1_2 = rearrange(x, 'h (w1 w) c -> (w1 h) w c', w1=3) # 将第0维度的shape×3,第1维度的shape除以3
print(x1_2)
print(x1_2.shape)
"""
tensor([[[ 1., 2., 3., 4.]],
[[13., 14., 15., 16.]],
[[ 5., 6., 7., 8.]],
[[17., 18., 19., 20.]],
[[ 9., 10., 11., 12.]],
[[21., 22., 23., 24.]]])
torch.Size([6, 1, 4])
"""
x1_3 = rearrange(x, 'h w c -> 1 h w c') # 张量扩维(维数),三维到四维,组成张量的元素不变
print(x1_3)
print(x1_3.shape)
"""
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]]])
torch.Size([1, 2, 3, 4])
"""
x1_4 = rearrange(x, 'h w c -> (h w c) 1') # 张量降维(维数),组成张量的元素不变
print(x1_4)
print(x1_4.shape)
"""
tensor([[ 1.],
[ 2.],
[ 3.],
[ 4.],
[ 5.],
[ 6.],
[ 7.],
[ 8.],
[ 9.],
[10.],
[11.],
[12.],
[13.],
[14.],
[15.],
[16.],
[17.],
[18.],
[19.],
[20.],
[21.],
[22.],
[23.],
[24.]])
torch.Size([24, 1])
"""
# repeat()函数通过在shape的某个维度(dim)上复制组成张量的元素实现对shape的调整
x2_1 = repeat(x, 'h w c -> h w c a', a=2) # 张量扩维(维数)
print(x2_1)
print(x2_1.shape)
"""
tensor([[[[ 1., 1.],
[ 2., 2.],
[ 3., 3.],
[ 4., 4.]],
[[ 5., 5.],
[ 6., 6.],
[ 7., 7.],
[ 8., 8.]],
[[ 9., 9.],
[10., 10.],
[11., 11.],
[12., 12.]]],
[[[13., 13.],
[14., 14.],
[15., 15.],
[16., 16.]],
[[17., 17.],
[18., 18.],
[19., 19.],
[20., 20.]],
[[21., 21.],
[22., 22.],
[23., 23.],
[24., 24.]]]])
torch.Size([2, 3, 4, 2])
"""
x2_2 = repeat(x, 'h w c -> (h1 h) w (c1 c)', h1=2, c1=2) # 对shape的任意维度进行扩增,通过元素复制
print(x2_2)
print(x2_2.shape)
"""
tensor([[[ 1., 2., 3., 4., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 5., 6., 7., 8.],
[ 9., 10., 11., 12., 9., 10., 11., 12.]],
[[13., 14., 15., 16., 13., 14., 15., 16.],
[17., 18., 19., 20., 17., 18., 19., 20.],
[21., 22., 23., 24., 21., 22., 23., 24.]],
[[ 1., 2., 3., 4., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 5., 6., 7., 8.],
[ 9., 10., 11., 12., 9., 10., 11., 12.]],
[[13., 14., 15., 16., 13., 14., 15., 16.],
[17., 18., 19., 20., 17., 18., 19., 20.],
[21., 22., 23., 24., 21., 22., 23., 24.]]])
torch.Size([4, 3, 8])
"""
# reduce()函数可以对张量的shape降维(维数),也可以在特定维度上减小,必须指定池化方式('max', 'min', 'mean')
x3_1 = reduce(x, 'b h w -> h w', 'max') # 'max' 代表采用最大池化降维,不同的池化类型导致最终不同的元素组成
print(x3_1)
print(x3_1.shape)
"""
tensor([[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]])
torch.Size([3, 4])
"""
x3_2 = reduce(x, 'b h w -> h w', 'min') # 'min' 代表最小池化,即选择对应位置的最小元素进行池化
print(x3_2)
print(x3_2.shape)
"""
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
torch.Size([3, 4])
"""
x3_3 = reduce(x, 'b h w -> h w', 'mean') # 'mean'代表平均池化,代表使用每个位置所有元素的平均进行池化
print(x3_3)
print(x3_3.shape)
"""
tensor([[ 7., 8., 9., 10.],
[11., 12., 13., 14.],
[15., 16., 17., 18.]])
torch.Size([3, 4])
"""
# reduce()也可调整shape特定维度上值
x3_4 = reduce(x, '(b1 b) h (w1 w) -> b h w', 'max', b1=2, w1=2) # 最大池化
print(x3_4)
print(x3_4.shape)
"""
tensor([[[15., 16.],
[19., 20.],
[23., 24.]]])
torch.Size([1, 3, 2])
"""
x3_5 = reduce(x, '(b1 b) h (w1 w) -> b h w', 'min', b1=2, w1=2) # 最小池化
print(x3_5)
print(x3_5.shape)
"""
tensor([[[ 1., 2.],
[ 5., 6.],
[ 9., 10.]]])
torch.Size([1, 3, 2])
"""
x3_6 = reduce(x, '(b1 b) h (w1 w) -> b h w', 'mean', b1=2, w1=2) # 平均池化
print(x3_6)
print(x3_6.shape)
"""
tensor([[[ 8., 9.],
[12., 13.],
[16., 17.]]])
torch.Size([1, 3, 2])
"""
print(x)
print(x.shape)
"""
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]])
torch.Size([2, 3, 4])
"""
# x.view(new_shape)函数将张量x.shape修改为new_shape,但必须保持组成的元素不变,所在new_shape是有一定限制的
x4_1 = x.view(6, 4)
print(x4_1)
print(x4_1.shape)
"""
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]])
torch.Size([6, 4])
"""
# x.view(new_shape)函数在new_shape中使用-1代表shape在该维度的值自动计算,需要满足组成元素不变
x4_2 = x.view(-1, 2, 3)
print(x4_2)
print(x4_2.shape)
"""
tensor([[[ 1., 2., 3.],
[ 4., 5., 6.]],
[[ 7., 8., 9.],
[10., 11., 12.]],
[[13., 14., 15.],
[16., 17., 18.]],
[[19., 20., 21.],
[22., 23., 24.]]])
torch.Size([4, 2, 3])
"""
x4_3 = x.view(2, -1, 2)
print(x4_3)
print(x4_3.shape)
"""
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.],
[19., 20.],
[21., 22.],
[23., 24.]]])
torch.Size([2, 6, 2])
"""
# torch.cat()函数,对两个张量在特定维度进行拼接,要求两个张量的shape在除拼接维度上均相等,否则报错
y = torch.tensor([[[ 1.],
[ 1.],
[ 1.]],
[[2.],
[2.],
[2.]]])
print(y.shape) # torch.Size([2, 3, 1])
x5_1 = torch.cat((x, y), dim=2) # (x, y)是拼接的两个张量,有先后顺序,dim=2是拼接维度的索引
print(x5_1)
print(x5_1.shape)
"""
tensor([[[ 1., 2., 3., 4., 1.],
[ 5., 6., 7., 8., 1.],
[ 9., 10., 11., 12., 1.]],
[[13., 14., 15., 16., 2.],
[17., 18., 19., 20., 2.],
[21., 22., 23., 24., 2.]]])
torch.Size([2, 3, 5])
"""
# x.reshape(new_shape)函数改变张量x.shape为指定的new_shape,需满足组成张量的元素保持不变
x6_1 = x.reshape(1, 6, 4)
print(x6_1)
print(x6_1.shape)
"""
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]])
torch.Size([1, 6, 4])
"""
# unsqueeze/squeeze() 在指定维度上添加或移除一个值为1的维度,从而达到升维/降维(维数)
x7_1 = x.unsqueeze(0) # 在索引为0的维度上添加一个值为1的维度
print(x7_1)
print(x7_1.shape)
"""
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]]])
torch.Size([1, 2, 3, 4])
"""
x7_2 = x7_1.squeeze(0) # 去除x.shape的dim=0的维度(若shape在该维度上为1),否则保持不变;若不指定索引,如x.squeeze()则在所有维度上均类此操作
print(x7_2)
print(x7_2.shape)
"""
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]])
torch.Size([2, 3, 4])
"""
# torch.stack((a, b), dim)函数实现升维(维数+1),并改变shape在指定维度(dim)的值
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
print(a.shape) # torch.Size([3, 3])
print(b.shape) # torch.Size([3, 3])
# 沿dim=0进行拼接:因为使用torch.stack()会总维数+1变成三维张量,沿dim=0可理解为在新三维张量shape的第0个维度上进行连接,即在最外面增加一层嵌套列表!
c_1 = torch.stack((a, b), dim=0)
print(c_1)
print(c_1.shape)
"""
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
torch.Size([2, 3, 3])
"""
# 沿dim=1拼接,对张量a、b而言,dim=1即代表行,即在行方向上进行张量a与b的拼接.所以原本对a而言,第一行只有[1, 2, 3]现在要拼接上[10, 20, 30],其他两行也是如此
c_2 = torch.stack((a, b), dim=1)
print(c_2)
print(c_2.shape)
"""
tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]])
torch.Size([3, 2, 3])
"""
# 沿dim=2进行拼接,对张量a、b而言,dim=1对应列,即在列方向上进行拼接
# 不失一般性,对a的第一列,原本是[1, 4, 7]的转置,现在要拼接上[10, 40, 70]的转置,其他两列也是这样。
c_3 = torch.stack((a, b), dim=2)
print(c_3)
print(c_3.shape)
"""
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]])
torch.Size([3, 3, 2])
"""
# x.transpose(dim1, dim2)函数交换shape在dim1和dim2两个维度的值(这里的dim类似于维度的索引)
x8_1 = x.transpose(2, 0) # 这里交换张量x的dim=0维度和dim=2维度
print(x8_1)
print(x8_1.shape)
"""
tensor([[[ 1., 13.],
[ 5., 17.],
[ 9., 21.]],
[[ 2., 14.],
[ 6., 18.],
[10., 22.]],
[[ 3., 15.],
[ 7., 19.],
[11., 23.]],
[[ 4., 16.],
[ 8., 20.],
[12., 24.]]])
torch.Size([4, 3, 2])
"""
# x.permute()函数与x.transpose(dim1, dim2)类似,允许一次性交换x.shape的多个维度
x8_2 = x.permute(1, 2, 0)
print(x8_2)
print(x8_2.shape)
"""
tensor([[[ 1., 13.],
[ 2., 14.],
[ 3., 15.],
[ 4., 16.]],
[[ 5., 17.],
[ 6., 18.],
[ 7., 19.],
[ 8., 20.]],
[[ 9., 21.],
[10., 22.],
[11., 23.],
[12., 24.]]])
torch.Size([3, 4, 2])
"""
# x.unbind(dim)函数在指定维度(dim)上进行拆分,移除shape在该维度的值,并返回拆分后的(多个)张量组成的元组.拆分张量的数量==shape在拆分维度(dim)的值
x9_1 = x.unbind(0) # 沿dim=0进行拆分,并去除shape在该维度的值,返回拆分后的多个张量组成的元组(turple())
print(x9_1)
print(x9_1[0].shape) # x9_1[0].shape == x9_1[1].shape
"""
(tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]]),
tensor([[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]))
torch.Size([3, 4])
"""
x9_2 = x.unbind(dim=1) # 沿dim=1进行拆分,返回拆分后的多个张量组成的元组(turple())
print(x9_2)
print(x9_2[0].shape)
"""
(tensor([[ 1., 2., 3., 4.],
[13., 14., 15., 16.]]),
tensor([[ 5., 6., 7., 8.],
[17., 18., 19., 20.]]),
tensor([[ 9., 10., 11., 12.],
[21., 22., 23., 24.]]))
torch.Size([2, 4])
"""
x9_3 = x.unbind(2) # 沿dim=2进行拆分,返回拆分后的多个张量组成的元组(turple())
print(x9_3)
print(x9_3[3].shape)
"""
(tensor([[ 1., 5., 9.],
[13., 17., 21.]]),
tensor([[ 2., 6., 10.],
[14., 18., 22.]]),
tensor([[ 3., 7., 11.],
[15., 19., 23.]]),
tensor([[ 4., 8., 12.],
[16., 20., 24.]]))
torch.Size([2, 3])
"""
对张量shape操作函数汇总,附代码(reshape(), torch.cat(), view(), torch.stack(), transpose(), permute(), unsqueeze)
于 2024-06-11 23:58:46 首次发布