torch.squeeze():对数据的维度进行压缩,当dim不设值时,去掉输入的tensor的所有维度为1的维度。
import torch
x = torch.rand(size=(2, 1, 2, 1, 2))
# 5
print(x.dim())
print(x.size())
# torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x)
# torch.Size([2, 2, 2])
print(y.size())
当dim为某一整数,判断该维的维度是否为1,若是则去掉,否则不变。
import torch
x = torch.rand(size=(1, 2, 1, 2, 1, 2))
# 6
print(x.dim())
print(x.size())
# torch.Size([1, 2, 1, 2, 1, 2])
y = torch.squeeze(x, 0) # dim=0表示第一维,且第一维的维度为1,所以去掉
# torch.Size([2, 1, 2, 1, 2])
print(y.size())
import torch
x = torch.rand(size=(1, 2, 1, 2, 1, 2))
# 6
print(x.dim())
print(x.size())
# torch.Size([1, 2, 1, 2, 1, 2])
y = torch.squeeze(x, 2) # dim=2表示第三维,且第三维的维度为1,所以去掉
# torch.Size([1, 2, 2, 1, 2])
print(y.size())
torch.unsqueeze():主要是对数据维度进行扩充
x = torch.tensor([1, 2, 3, 4])
print(x.size())
# torch.Size([4])
x = torch.unsqueeze(x, 0)
print(x.size())
# torch.Size([1, 4])
print(x)
# tensor([[1, 2, 3, 4]])
x = torch.tensor([1, 2, 3, 4])
print(x.size())
x = torch.unsqueeze(x, 1)
print(x.size())
# torch.Size([4, 1])
print(x)
"""tensor([[1],
[2],
[3],
[4]])"""