unsqueeze,指定维度插入尺寸为1的新张量,下例中,起初(2,6),2为第0维,运行生成,(1,2,6)生成一个三维的张量。
import torch
A = torch.arange(12.0).reshape(2,6)
print(A)
B = torch.unsqueeze(A,dim=0)
print(B.shape)
print(B)
tensor([[ 0., 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10., 11.]])
torch.Size([1, 2, 6])
tensor([[[ 0., 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10., 11.]]])
dim=1时,运行结果如下:
tensor([[ 0., 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10., 11.]])
torch.Size([2, 1, 6])
tensor([[[ 0., 1., 2., 3., 4., 5.]],
[[ 6., 7., 8., 9., 10., 11.]]])
torch.squeeze()函数:移除所有维度为1的维度,[2,1,6]是三维的,[1,6]是二维的。
import torch
A = torch.arange(12.0).reshape(2,6)
B = torch.unsqueeze(A,dim=1)
print(B.shape)#(2,1,6),2为0维
print(B)
C = B.squeeze()
print(C.shape)
print(C)
D = torch.squeeze(B,dim=1)#移除指定维度为1的维度
print(D)
print(D.shape)
torch.Size([2, 1, 6])
tensor([[[ 0., 1., 2., 3., 4., 5.]],
[[ 6., 7., 8., 9., 10., 11.]]])
torch.Size([2, 6])
tensor([[ 0., 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10., 11.]])
tensor([[ 0., 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10., 11.]])
torch.Size([2, 6])