Pytorch Lightning 进阶 2 - 正向传播和反向传播具体的代码位置

TL; TR

环节核心任务代码位置关键操作
正向传播计算预测值和损失self.step(batch) 中的模型前向计算model(inputs)loss_fn(outputs, targets)
反向传播计算梯度并更新参数trainer.fit() 内部loss.backward()optimizer.step()

一、正向传播(Forward Propagation)的位置

1. 核心定义

正向传播是数据从输入层流经网络各层,最终得到输出预测值的过程,主要完成“计算预测值”的任务。

2. 代码中的体现

正向传播的关键步骤大概率包含在 self.step(batch) 方法中。具体来说:

def training_step(self, batch, batch_idx):
    # ... 其他逻辑
    step_loss = self.step(batch)  # 正向传播和损失计算的核心位置
    # ... 损失记录逻辑
  • self.step(batch) 的内部逻辑(假设实现如下):
    def step(self, batch):
        inputs, targets = batch
        # 1. 模型前向计算(正向传播的核心)
        outputs = self.model(inputs)  # 输入数据通过模型各层,得到预测值
        # 2. 计算损失(正向传播的后续步骤)
        loss = self.loss_fn(outputs, targets)
        return loss
    
  • 正向传播的具体流程
    1. 输入数据 inputs 进入模型,依次经过各层神经网络(如卷积层、全连接层等)。
    2. 每层根据当前参数对输入进行计算,生成中间特征或最终输出 outputs
    3. 输出值与真实标签 targets 对比,通过损失函数(如交叉熵、均方差)计算损失 loss

二、反向传播(Backpropagation)的位置

1. 核心定义

反向传播是从损失函数出发,反向计算梯度并更新模型参数的过程,主要完成“优化参数”的任务。

2. 代码中的体现

反向传播的触发和执行通常由训练框架(如PyTorch Lightning)自动管理,具体位置可能在:

  • trainer.fit() 内部:框架会自动处理反向传播和参数更新,无需手动编写梯度计算代码。
  • 隐含的梯度计算流程
    1. step_loss 被计算时(正向传播的最后一步),PyTorch会自动构建计算图,记录所有操作的梯度依赖关系。
    2. 框架在后台调用 loss.backward() 计算梯度(反向传播的核心),并通过优化器(如SGD、Adam)执行 optimizer.step() 更新参数。
3. 反向传播的关键逻辑(框架内部流程)
# 假设框架内部的反向传播逻辑(实际由trainer.fit管理)
def _run_backward(self, loss, optimizer):
    # 1. 清空之前的梯度(避免累积)
    optimizer.zero_grad()
    # 2. 反向传播计算梯度(从损失函数反向推导各参数的梯度)
    loss.backward()
    # 3. 根据梯度更新参数(反向传播的最终目标)
    optimizer.step()

三、补充说明:框架对反向传播的封装

在PyTorch Lightning等高级框架中,反向传播的底层操作(如梯度计算、参数更新)通常被封装在训练器(Trainer)中,用户无需手动编写。代码中的 trainer.fit() 会自动完成以下步骤:

  1. 遍历数据加载器(train_loader),每次获取一个批次(batch)。
  2. 调用 training_step 执行正向传播和损失计算。
  3. 累积梯度(若启用梯度累积),并在合适时机触发反向传播和参数更新。

四、延伸:手动实现反向传播的示例(非框架场景)

如果不使用框架,手动实现反向传播的代码结构如下,可帮助理解核心逻辑:

# 正向传播
outputs = model(inputs)
loss = loss_fn(outputs, targets)

# 反向传播
optimizer.zero_grad()  # 清空梯度
loss.backward()        # 计算梯度(反向传播)
optimizer.step()       # 更新参数

这些步骤被 trainer.fit() 隐式处理,因此反向传播的具体代码不可见,但逻辑上存在于框架内部。

通过以上分析,可以明确:正向传播的核心在 self.step(batch) 中的模型前向计算,反向传播由 trainer.fit() 框架自动管理,通过损失函数触发梯度计算和参数更新。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值