1. unsqueeze(input, dim, out=None)函数 - 升维作用
在指定的地方上增加一个维度
0(-2) [行扩展]: 表示在张量最外层增加一个中括号变成第一维
1(-1) [列扩展]:表示
>>> input = torch.arange(0,6)
>>> input
tensor([0, 1, 2, 3, 4, 5])
>>> input.shape
torch.Size([6])
>>> print(input.unsqueeze(0))
tensor([[0, 1, 2, 3, 4, 5]])
>>> print(input.unsqueeze(0).shape)
torch.Size([1, 6])
>>> print(input.unsqueeze(1))
tensor([[0],
[1],
[2],
[3],
[4],
[5]])
>>> print(input.unsqueeze(1).shape)
torch.Size([6, 1])
2. squeeze(input,dim,out=None) 降维函数
将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
小例子
如果是一个列表的tensor例如x变量想要转换成相同维度的tensor可以采用如下方式:
1)循环遍历列表中每个张量s,先使用unsqueeze(0)将每个张量s升维。
形状由torch.Size([3])变为torch.Size([1, 3])
【变化前:tensor([0, 1, 2]) 变化后:tensor([[0, 1, 2]])】
2)同时使用torch.cat()将其拼接起来 dim=0 表示横向拼接,否则竖向拼接
dim = 0 结果:
tensor([[0, 1, 2],
[1, 0, 2],
[1, 2, 0],
[2, 1, 0]])
dim = 1 结果:
tensor([[0, 1, 2, 1, 0, 2, 1, 2, 0, 2, 1, 0]])
>>> import torch
>>> x = [torch.tensor([0,1,2]),torch.tensor([1,0,2]),torch.tensor([1,2,0]),torch.tensor([2,1,0]),]
>>> x
[tensor([0, 1, 2]), tensor([1, 0, 2]), tensor([1, 2, 0]), tensor([2, 1, 0])]
>>> x = torch.cat([s.unsqueeze(0) for s in l],0)
>>> x
tensor([[0, 1, 2],
[1, 0, 2],
[1, 2, 0],
[2, 1, 0]])
3. max()的用法
更加详细参见我的另一篇文章:Pytorch学习-torch.max()和min()深度解析
non_final_next_states.max(1)[1].detach()
# 行维度 .max(1)[0] 返回values的最大值列表 .max(1)[1]返回最大值index列表
# 列维度 .max(0)[0] 返回values的最大值列表 .max(0)[1]返回最大值index列表
4. detach() 和detach_()
torch.detach() - 返回一个新的没有梯度的tensor [生成一个新的tensor]
返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
torch.detach_() - 直接修改该tensor[对其本身的更改],将其设置为无自动计算梯度的张量
将一个tensor从创建它的图中分离,并把它设置成叶子tensor
5. torch.Tensor和torch.tensor的区别
在Pytorch中,Tensor和tensor都用于生成新的张量。
torch.Tensor() 生成单精度浮点型张量
- torch.Tensor()是Python类,更明确的说,是默认张量类型torch.FloatTensor()的别名,torch.Tensor([1,2]) 会调用Tensor类的构造函数__init__,生成单精度浮点类型的张量。
torch.tensor() 根据原始data生成对应类型的张量
torch.tensor()仅仅是Python的函数,函数原型是:
torch.tensor(data, dtype=None, device=None, requires_grad=False)
其中data可以是:list, tuple, array, scalar等类型。
torch.tensor()可以从data中的数据部分做拷贝(而不是直接引用),根据原始数据类型生成相应的torch.LongTensor,torch.FloatTensor,torch.DoubleTensor。