PyTorch autogradqs_tutorial 分析

本文介绍了PyTorch中的自动求导机制Autograd,它是机器学习中反向传播的关键。通过张量的requires_grad属性和计算图,可以自动计算损失函数的梯度。文中以一个简单的线性回归模型为例,展示了如何使用Autograd进行梯度计算,并介绍了如何构建神经网络并利用自动求导进行反向传播。此外,还讨论了如何禁用梯度跟踪以提高效率,以及如何计算雅可比积分。
摘要由CSDN通过智能技术生成

请添加图片描述
autogradqs_tutorial 介绍了自动求导(Autograd)的基础知识。自动求导是机器学习中重要的一环,它能够自动计算并优化损失函数的梯度。本文首先简要地介绍了 PyTorch 中的张量(Tensor)和计算图(Computation Graph)的基本概念,然后重点讲解了张量的 requires_grad 属性和如何计算张量的梯度。最后通过一个简单的线性回归模型的例子来演示自动求导的实际应用。该教程还介绍了如何使用 torch.nn 模块构建神经网络,并使用自动求导进行反向传播。这些基础知识对于深入理解和使用 PyTorch 进行机器学习非常重要。

自动微分计算 torch.autograd

训练神经网络时,最常用的算法是 反向传播(back propagation)。在该算法中,根据损失函数对给定参数的 梯度(gradient) 来调整参数(模型权重)。

为了计算这些梯度,PyTorch 内置了一个微分引擎 torch.autograd。它支持自动计算任何计算图的梯度。

考虑最简单的一层神经网络,它有输入 x,参数 wb,以及某个损失函数。它可以在 PyTorch 中如下定义:

import torch

x = torch.ones(5)  # 输入张量
y = torch.zeros(3)  # 期望输出
w = torch.randn(5, 3, requires_grad=True) # 参数w,随机初始化,设置requires_grad为True,表示需要计算它的梯度
b = torch.randn(3, requires_grad=True) # 参数b,随机初始化,设置requires_grad为True,表示需要计算它的梯度
z = torch.matmul(x, w)+b # 神经网络前向传播
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y) # 计算损失函数,binary_cross_entropy_with_logits是二分类损失函数,梯度可以自动计算

张量、函数和计算图

此代码定义了以下 计算图

计算图

在此网络中,wb 是需要优化参数。因此,我们需要能够计算损失函数对这些变量的梯度。为此,我们设置这些张量的 requires_grad 属性。

注意:

  • 您可以在创建张量时设置requires_grad的值,也可以通过使用x.requires_grad_(True)方法在以后设置。

一个应用于张量以构建计算图的函数实际上是 Function 类的对象。此对象知道如何在正向传播时计算函数以及如何在 反向传播 步骤中计算其导数。反向传播函数的引用存储在张量的 grad_fn 属性中。你可以在文档中了解更多有关 Function 的信息。

打印输出变量z和loss的梯度函数。

print(f"变量 z 的梯度函数为:{
     z.grad_fn}")
print(f"变量 loss 的梯度函数为:{
     loss.grad_fn}")

计算梯度

为了优化神经网络中的参数权重,我们需要计算损失函数对参数的导数,即我们需要在一些固定的 xy 值下计算 loss 对 ‘w’ 和 ‘b’ 的导数,也就是 ∂ l o s s ∂ w \frac{\partial loss}{\partial w} wloss ∂ l o s s ∂ b \frac{\partial loss}{\partial b} <

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

「已注销」

不打赏也没关系,点点关注呀

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值