1 矩阵方面
1.1 torch.unsqueeze(input, dim, out=None)
作用:拓展维度
参数:
tensor (Tensor) – 输入张量
dim (int) – 插入维度的索引
out (Tensor, optional) – 返回张量
import torch
x = torch.Tensor([1, 2, 3, 4]) # torch.Tensor是默认的tensor类型(torch.FlaotTensor)的简称。
print('-' * 50)
print(x) # tensor([1., 2., 3., 4.])
print(x.size()) # torch.Size([4])
print(x.dim()) # 1
print(x.numpy()) # [1. 2. 3. 4.]
print('-' * 50)
print(torch.unsqueeze(x, 0)) # tensor([[1., 2., 3., 4.]])一维变二维
print(torch.unsqueeze(x, 0).size()) # torch.Size([1, 4])
print(torch.unsqueeze(x, 0).dim()) # 2
print(torch.unsqueeze(x, 0).numpy()) # [[1. 2. 3. 4.]]
print('-' * 50)
print(torch.unsqueeze(x, 1))
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
print(torch.unsqueeze(x, 1).size()) # torch.Size([4, 1])
print(torch.unsqueeze(x, 1).dim()) # 2
print('-' * 50)
print(torch.unsqueeze(x, -1))
#相当于torch.unsqueeze(x, 1)
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
print(torch.unsqueeze(x, -1).size()) # torch.Size([4, 1])
print(torch.unsqueeze(x, -1).dim()) # 2
1.2 torch.squeeze_()
这里提一下torch.unsqueeze()和torch.unsqueeze_()
torch的F_()函数和F()函数作用上是一致的,只是F_()函数是作用于变量本身,也就是不占用额外内存(in place),比如:
import torch
x=torch.tensor([1,2,3,4])
a=torch.unsqueeze(x,0)
print(a) # tensor([[1., 2., 3., 4.]])
print(x.unzqueeze_(0)) #tensor([[1., 2., 3., 4.]])
print(x) #tensor([[1., 2., 3., 4.]])
# x本身的值已经改变
1.3 torch.squeeze(input, dim=None, out=None)
作用:降维
将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
m = torch.zeros(2, 1, 2, 1, 2)
print(m.size()) # torch.Size([2, 1, 2, 1, 2])
n = torch.squeeze(m)
print(n.size()) # torch.Size([2, 2, 2])
n = torch.squeeze(m, 1) # 当给定dim时,那么挤压操作只在给定维度上
print(n.size()) # torch.Size([2, 2, 1, 2])
n = torch.squeeze(m, 0) #给定的dim不是1,不会挤压
print(n.size()) # torch.Size([2, 1, 2, 1, 2])
n = torch.squeeze(m, 2)
print(n.size()) # torch.Size([2, 1, 2, 1, 2])
n = torch.squeeze(m, 3)
print(n.size()) # torch.Size([2, 1, 2, 2])
1.4 torch.view(dim1, dim2)
作用:相当于resize
import torch
a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=torch.Tensor([1,2,3,4,5,6])
print(a.view(1,6)) #tensor([[1., 2., 3., 4., 5., 6.]])
print(b.view(1,6))
print(a.view(3,2))
#tensor([[1., 2.],
# [3., 4.],
# [5., 6.]])
1.4 torch.permute()
作用:将tensor维度换位