Pytorch学习- 小型知识点汇总 unsqueeze()/squeeze() 和 .max() 等等

1. unsqueeze(input, dim, out=None)函数 - 升维作用

参考链接

在指定的地方上增加一个维度

0(-2) [行扩展]: 表示在张量最外层增加一个中括号变成第一维
1(-1) [列扩展]:表示

>>> input = torch.arange(0,6)
>>> input
tensor([0, 1, 2, 3, 4, 5])
>>> input.shape
torch.Size([6])

>>> print(input.unsqueeze(0))
tensor([[0, 1, 2, 3, 4, 5]])
>>> print(input.unsqueeze(0).shape)
torch.Size([1, 6])

>>> print(input.unsqueeze(1))
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5]])
>>> print(input.unsqueeze(1).shape)
torch.Size([6, 1])

2. squeeze(input,dim,out=None) 降维函数

将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

小例子

如果是一个列表的tensor例如x变量想要转换成相同维度的tensor可以采用如下方式:
1)循环遍历列表中每个张量s,先使用unsqueeze(0)将每个张量s升维。
形状由torch.Size([3])变为torch.Size([1, 3])
【变化前:tensor([0, 1, 2]) 变化后:tensor([[0, 1, 2]])】

2)同时使用torch.cat()将其拼接起来 dim=0 表示横向拼接,否则竖向拼接
dim = 0 结果:

tensor([[0, 1, 2],
        [1, 0, 2],
        [1, 2, 0],
        [2, 1, 0]])

dim = 1 结果:

tensor([[0, 1, 2, 1, 0, 2, 1, 2, 0, 2, 1, 0]])
>>> import torch
>>> x = [torch.tensor([0,1,2]),torch.tensor([1,0,2]),torch.tensor([1,2,0]),torch.tensor([2,1,0]),]
>>> x
[tensor([0, 1, 2]), tensor([1, 0, 2]), tensor([1, 2, 0]), tensor([2, 1, 0])]
>>> x = torch.cat([s.unsqueeze(0) for s in l],0)
>>> x
tensor([[0, 1, 2],
        [1, 0, 2],
        [1, 2, 0],
        [2, 1, 0]])

3. max()的用法

更加详细参见我的另一篇文章:Pytorch学习-torch.max()和min()深度解析
non_final_next_states.max(1)[1].detach()
# 行维度 .max(1)[0] 返回values的最大值列表 .max(1)[1]返回最大值index列表
# 列维度 .max(0)[0] 返回values的最大值列表 .max(0)[1]返回最大值index列表

4. detach() 和detach_()

参考链接

torch.detach() - 返回一个新的没有梯度的tensor [生成一个新的tensor]

返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。

即使之后重新将它的requires_grad置为true,它也不会具有梯度grad

torch.detach_() - 直接修改该tensor[对其本身的更改],将其设置为无自动计算梯度的张量

将一个tensor从创建它的图中分离,并把它设置成叶子tensor

5. torch.Tensor和torch.tensor的区别

参考

在Pytorch中,Tensor和tensor都用于生成新的张量。

torch.Tensor() 生成单精度浮点型张量

  • torch.Tensor()是Python类,更明确的说,是默认张量类型torch.FloatTensor()的别名,torch.Tensor([1,2]) 会调用Tensor类的构造函数__init__,生成单精度浮点类型的张量。

torch.tensor() 根据原始data生成对应类型的张量

torch.tensor()仅仅是Python的函数,函数原型是:

torch.tensor(data, dtype=None, device=None, requires_grad=False)

其中data可以是:list, tuple, array, scalar等类型。
torch.tensor()可以从data中的数据部分做拷贝(而不是直接引用),根据原始数据类型生成相应的torch.LongTensor,torch.FloatTensor,torch.DoubleTensor。

5.torch.cat() 的用法

参考链接

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值