示例文件 test.py
import torch
A = torch.arange(6)
A = A.reshape(1, 1, 2, 3, -1)
B = A.squeeze()
print(A)
print(B)
C = B.unsqueeze(0)
D = B.unsqueeze(1)
E = B.unsqueeze(2)
print(C)
print(D)
print(E)
终端命令行和运行结果
<user>python test.py
tensor([[[[[0],
[1],
[2]],
[[3],
[4],
[5]]]]])
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[[0, 1, 2],
[3, 4, 5]]])
tensor([[[0, 1, 2]],
[[3, 4, 5]]])
tensor([[[0],
[1],
[2]],
[[3],
[4],
[5]]])
从运算结果来看:
1)tensor.squeeze() 将输入张量形状中的1去除并返回。如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)。
2)tensor.unsqueeze() 扩展维度,返回一个新的张量,对输入的既定位置插入维度1。