1. torch.unsqueeze 详解
torch.unsqueeze(input, dim, out=None)
- 作用:扩展维度
返回一个新的张量,对输入的既定位置插入维度 1
- 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
如果dim为负,则将会被转化dim+input.dim()+1
- 参数:
tensor (Tensor)
– 输入张量dim (int)
– 插入维度的索引out (Tensor, optional)
– 结果张量
import torch
# 创建一个二维张量(矩阵)
x = torch.tensor([[1, 2], [3, 4]])
print(x.size())
x维度是 2 * 2
# 在第0维增加一个维度
y = torch.unsqueeze(x, 0)
print(y.size())
在第0维增加了一个维度,y维度是 1 * 2 * 2
2. unsqueeze_
和 unsqueeze
的区别
PyTorch中的 XXX_ 和 XXX 实现的功能都是相同的,唯一不同的是前者进行的是 in_place 操作。
unsqueeze_
和 unsqueeze
实现一样的功能,区别在于 unsqueeze_
是 in_place 操作,即 unsqueeze
不会对使用 unsqueeze
的 tensor 进行改变,想要获取 unsqueeze
后的值必须赋予个新值, unsqueeze_
则会对自己改变。
print("-" * 50)
a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])
b = torch.unsqueeze(a, 1)
print(b)
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
print(a)
# tensor([1., 2., 3., 4.])
print("-" * 50)
a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])
print(a.unsqueeze_(1))
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
print(a)
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
3. torch.squeeze 详解
- 作用:降维
torch.squeeze(input, dim=None, out=None)
将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0)
将会保持张量不变,只有用 squeeze(input, 1)
,形状会变成 (A×B)。
- 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
- 参数:
input (Tensor)
– 输入张量dim (int, optional)
– 如果给定,则input只会在给定维度挤压out (Tensor, optional)
– 输出张量
为何只去掉 1 呢?
多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。
import torch
m = torch.zeros(2, 1, 2, 1, 2)
print(m.size()) # torch.Size([2, 1, 2, 1, 2])
# 去掉所有大小为1的维度
n = torch.squeeze(m)
print(n.size()) # torch.Size([2, 2, 2])
# 去掉指定维度上大小为1的维度,如果指定的维度不是1,则不会对该维度进行操作
n = torch.squeeze(m, 0) # 当给定dim时,那么挤压操作只在给定维度上
print(n.size()) # torch.Size([2, 1, 2, 1, 2])
# 去掉指定维度上大小为1的维度,如果指定的维度是1,则会对该维度进行删除
n = torch.squeeze(m, 1)
print(n.size()) # torch.Size([2, 2, 1, 2])
当然你也可以使用:
x.squeeze() 来替换 torch.squeeze(x)