Torch
的 squeeze()
和unsqueeze()
函数,作用分别是降维和升维
(1)squeeze()
实现降维
例1,创建一个维度为[2, 3]
的向量,去掉其中一维度,发现并没有起作用,因为被降维的维数必须为1
才可以。
import torch
vec = torch.arange(6)
vec = vec.view(2, 3)
print(vec.shape, vec) # torch.Size([2, 3]) tensor([[0, 1, 2], [3, 4, 5]])
vec = vec.squeeze(-2)
print(vec.shape, vec) # torch.Size([2, 3]) tensor([[0, 1, 2], [3, 4, 5]])
例2,创建一个维度为[1, 6]
的向量,去掉其中一维度
import torch
vec = torch.arange(6)
vec = vec.view(1, 6)
print(vec.shape, vec) # torch.Size([1, 6]) tensor([[0, 1, 2, 3, 4, 5]])
vec = vec.squeeze(-2)
print(vec.shape, vec) # torch.Size([6]) tensor([0, 1, 2, 3, 4, 5])
注意: 输入的参数和被降的维数。
(2)unsqueeze()
实现升维
import torch
vec = torch.arange(6)
vec = vec.view(2, 3)
print(vec.shape, vec) # torch.Size([2, 3]) tensor([[0, 1, 2],[3, 4, 5]])
vec = vec.unsqueeze(0)
print(vec.shape, vec) # torch.Size([1, 2, 3]) tensor([[[0, 1, 2],[3, 4, 5]]])
注意: 输入的参数表示要升的维度。
声明: 总结学习,有问题或不当之处,可以批评指正哦,谢谢。