65 篇文章 4 订阅

# 1 为何要增减维度

numpy中squeeze函数，无unsqueeze函数，numpy中增加维度用np.expand_dims(x, axis)函数，可参考链接
torch的tensor中，两个函数都有。

# 2 numpy中的squeeze 函数

arr_1 = numpy.squeeze(arr, axis = None)


arr表示输入的数组；
axis的取值可为None或0，默认为None，表示删除所有shape为1的维度。axis为0表示删除 一层 shape为1的维度

import numpy as np

arr = np.array([[[[1,2,3],[4,5,6]]]])
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")

arr_1 = np.squeeze(arr, axis=0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")

arr_2 = np.squeeze(arr, axis=None)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')


<class 'numpy.ndarray'>
[[[1 2 3]
[4 5 6]]]
(1, 2, 3)
==========================
<class 'numpy.ndarray'>
[[1 2 3]
[4 5 6]]
(2, 3)
==========================
<class 'numpy.ndarray'>
[[1 2 3]
[4 5 6]]
(2, 3)


# 3 torch中的squeeze 函数

import torch

arr = torch.Tensor(1, 3, 1, 5)
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")

# 里面的数字表示压缩哪个维度，依旧只有维度为1才能压
arr_1 = arr.squeeze(0)          # 压缩第一维度，且第一维度是1，可压缩
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")

arr_2 = arr.squeeze(1)        # 压缩第二维度，但第二维度不是1，故不可压缩
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
print("==========================")

arr_3 = arr.squeeze(2)        # 压缩第三维度，且第三维度是1，可压缩
print(type(arr_3), arr_3, arr_3.shape, sep='\n')


<class 'torch.Tensor'>
tensor([[[[1.9349e-19, 4.5445e+30, 4.7429e+30, 7.1354e+31, 7.1118e-04]],

[[1.7444e+28, 7.3909e+22, 1.8727e+31, 1.4182e-19, 4.6168e+24]],

[[4.2964e+24, 1.2514e-14, 8.9634e-33, 7.1345e+31, 7.1118e-04]]]])
torch.Size([1, 3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[1.9349e-19, 4.5445e+30, 4.7429e+30, 7.1354e+31, 7.1118e-04]],

[[1.7444e+28, 7.3909e+22, 1.8727e+31, 1.4182e-19, 4.6168e+24]],

[[4.2964e+24, 1.2514e-14, 8.9634e-33, 7.1345e+31, 7.1118e-04]]])
torch.Size([3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[[1.9349e-19, 4.5445e+30, 4.7429e+30, 7.1354e+31, 7.1118e-04]],

[[1.7444e+28, 7.3909e+22, 1.8727e+31, 1.4182e-19, 4.6168e+24]],

[[4.2964e+24, 1.2514e-14, 8.9634e-33, 7.1345e+31, 7.1118e-04]]]])
torch.Size([1, 3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[1.9349e-19, 4.5445e+30, 4.7429e+30, 7.1354e+31, 7.1118e-04],
[1.7444e+28, 7.3909e+22, 1.8727e+31, 1.4182e-19, 4.6168e+24],
[4.2964e+24, 1.2514e-14, 8.9634e-33, 7.1345e+31, 7.1118e-04]]])
torch.Size([1, 3, 5])


# 4 torch中的unsqueeze 函数

import torch

arr = torch.Tensor(3, 5)
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")

# 本身是二维，增加一维变三维，可通过0,1,2三个数字来控制维度增加到哪
arr_1 = arr.unsqueeze(0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")

arr_2 = arr.unsqueeze(1)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
print("==========================")

arr_3 = arr.unsqueeze(2)        # 数字再大就报错了
print(type(arr_3), arr_3, arr_3.shape, sep='\n')


<class 'torch.Tensor'>
tensor([[3.2483e+33, 1.9690e-19, 6.8589e+22, 1.3340e+31, 1.1708e-19],
[7.2128e+22, 9.2216e+29, 7.5546e+31, 1.6932e+22, 3.0728e+32],
[2.9514e+29, 2.8940e+12, 7.5338e+28, 1.8037e+28, 3.4740e-12]])
torch.Size([3, 5])
==========================
<class 'torch.Tensor'>
tensor([[[3.2483e+33, 1.9690e-19, 6.8589e+22, 1.3340e+31, 1.1708e-19],
[7.2128e+22, 9.2216e+29, 7.5546e+31, 1.6932e+22, 3.0728e+32],
[2.9514e+29, 2.8940e+12, 7.5338e+28, 1.8037e+28, 3.4740e-12]]])
torch.Size([1, 3, 5])
==========================
<class 'torch.Tensor'>
tensor([[[3.2483e+33, 1.9690e-19, 6.8589e+22, 1.3340e+31, 1.1708e-19]],

[[7.2128e+22, 9.2216e+29, 7.5546e+31, 1.6932e+22, 3.0728e+32]],

[[2.9514e+29, 2.8940e+12, 7.5338e+28, 1.8037e+28, 3.4740e-12]]])
torch.Size([3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[3.2483e+33],
[1.9690e-19],
[6.8589e+22],
[1.3340e+31],
[1.1708e-19]],

[[7.2128e+22],
[9.2216e+29],
[7.5546e+31],
[1.6932e+22],
[3.0728e+32]],

[[2.9514e+29],
[2.8940e+12],
[7.5338e+28],
[1.8037e+28],
[3.4740e-12]]])
torch.Size([3, 5, 1])

• 7
点赞
• 11
收藏
觉得还不错? 一键收藏
• 打赏
• 0
评论
07-25 926
04-24 158
11-24 207
12-07 3033
02-09 14万+
01-26 4141
04-24 78
07-13 5万+
10-20 720

### “相关推荐”对你有帮助么？

• 非常没帮助
• 没帮助
• 一般
• 有帮助
• 非常有帮助

¥2 ¥4 ¥6 ¥10 ¥20

1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载，可以购买VIP、C币套餐、付费专栏及课程。