Pytorch API接口说明

函数说明:Pytorch官网

持续更新…

torch.squeeze

torch.squeeze(input, dim=None) → Tensor
Returns a tensor with all specified dimensions of input of size 1 removed.

例如,如果input的形状为: (A×1×B×C×1×D),那么 input.squeeze() 将具有以下形状: (A×B×C×D) 。

当给出 dim 时,仅在给定的维度上进行压缩操作。如果输入的形状为: (A×1×B) ,squeeze(input, 0) 保持张量不变,但squeeze(input, 1) 会将张量压缩到形状 (A×B) 。

注意:返回的张量与输入张量共享存储,因此更改一个张量的内容将更改另一个张量的内容。

torch.unsqueeze

torch.unsqueeze(input, dim) → Tensor
Returns a new tensor with a dimension of size one inserted at the specified position.

返回的张量与该张量共享相同的基础数据。

可以使用 [-input.dim() - 1, input.dim() + 1) 范围内的dim值。负dim等效于在dim = dim + input.dim() + 1 处应用unsqueeze()。

torch.max

torch.max(input) → Tensor
Returns the maximum value of all elements in the input tensor.


>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763,  0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)

torch.max(input, dim, keepdim=False, *, out=None)
Parameters:
input (Tensor) – the input tensor.
dim (int) – the dimension to reduce.
keepdim (bool) – whether the output tensor has dim retained or not. Default: False.
Keyword Arguments:
out (tuple, optional) – the result tuple of two output tensors (max, max_indices)

返回一个命名元组(values,indices),其中values是给定维度dim中输入张量每行的最大值。 indexs 是找到的每个最大值(argmax)的索引位置。

如果 keepdim 为 True,则输出张量与输入的大小相同,但维度 dim 的大小为 1。否则,dim 会被压缩,导致输出张量的维度比输入减1。

>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))

parameters()

功能:模型参数,用循环拿到参数
返回:layer-param(参数的迭代器),list类型

named_parameters()

功能:模型参数,用循环拿到参数
返回:layer-name,layer-param,打包成一个元组然后再存到list

只保存可学习、可被更新的参数,model.buffer()中的参数不包含;
模型参数tensor的require_grad属性都是True。

state_dict()

功能:模型参数
返回:layer_name : layer_param的键值信息存储为dict。

保存所有layer中的所有参数;
模型参数tensor的require_grad属性都是False。

load_state_dict()

torch.load_state_dict()函数是用于将预训练的参数权重加载到新的模型之中。

# 加载官方模型参数到模型中
model.load_state_dict(weights_dict, strict=False)

当strict=True,要求预训练权重层数的键值与新构建的模型中的权重层数名称完全吻合;如果不一致,则上述代码就会报错:说key对应不上。

当strict=False,可以解决这个问题,训练权重中与新构建网络中匹配层的键值就进行使用,没有的就默认初始化。

torch.tensor

功能:将其他数据类型转换为tensor型数据。默认整型数据类型为torch.int64,浮点型为torch.float32

torch.tensor(data,dtype=None,device=None,requires_grad=False,pin_memory=False)

参数:

  1. data:要转换的数据,可以是列表、元组、NumPy、ndarray、标量和其他类型
  2. dtype:数据类型,默认和data类型一致,用来修改变量类型
  3. device ( torch.device, optional):构造张量的设备。如果 None 并且数据是张量,则使用数据设备。如果None 并且数据不是张量,则结果张量在 CPU 上构建。
  4. requires_grad:是否保留对应的梯度信息,默认为False。
  5. pin_memory:若值为True,返回的张量将分配在固定内存中。仅适用于
    CPU 张量。默认值:False。
import torch
a=torch.tensor([1,2])
b=torch.tensor([1,2],dtype=torch.float32)
print(a)
print(b)

输出:

tensor([1, 2])
tensor([1., 2.])

torch.Tensor/FloatTensor

功能:Python类,FloatTensor的简称,生成一个数据类型为 32 位浮点数的张量,如果没传入数据就返回空张量,如果有列表或者 narray ,则返回其对应张量。

x = torch.FloatTensor([1,2])
y = torch.FloatTensor(1,2)
print(x)
print(y)

输出:

tensor([1., 2.])
tensor([[1.1632e+33, 3.2598e-12]])

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
Gathers values along an axis specified by dim.

对一个3D tensor的输出可以表示为

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

input和index必须具有相同的维度数,即规定input和index是同维张量,即input是2维张量,index也必须是2维张量。对于所有维度 d != dim.out ,还要求 index.size(d) <= input.size(d)具有与index相同的形状。请注意,input和index不会相互广播。

Parameters:
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
Keyword Arguments:
sparse_grad (bool, optional) – If True, gradient w.r.t. input will be a sparse tensor.
out (Tensor, optional) – the destination tensor

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

参考PyTorch中的高级索引方法——gather详解
input是张量A

A = tensor([[0, 1],
        [2, 3],
        [4, 5]])

在这里插入图片描述
在这里插入图片描述

detach

功能:detach函数将其从计算图中分离出来,避免在反向传播过程中产生新的梯度。返回的是一个新的张量,但是它与原始张量共享底层存储。

  • 28
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

余加木

想喝蜜雪冰城柠檬水(≧≦)/

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值