TL; TR
环节 | 核心任务 | 代码位置 | 关键操作 |
---|---|---|---|
正向传播 | 计算预测值和损失 | self.step(batch) 中的模型前向计算 | model(inputs) 、loss_fn(outputs, targets) |
反向传播 | 计算梯度并更新参数 | trainer.fit() 内部 | loss.backward() 、optimizer.step() |
一、正向传播(Forward Propagation)的位置
1. 核心定义
正向传播是数据从输入层流经网络各层,最终得到输出预测值的过程,主要完成“计算预测值”的任务。
2. 代码中的体现
正向传播的关键步骤大概率包含在 self.step(batch)
方法中。具体来说:
def training_step(self, batch, batch_idx):
# ... 其他逻辑
step_loss = self.step(batch) # 正向传播和损失计算的核心位置
# ... 损失记录逻辑
self.step(batch)
的内部逻辑(假设实现如下):def step(self, batch): inputs, targets = batch # 1. 模型前向计算(正向传播的核心) outputs = self.model(inputs) # 输入数据通过模型各层,得到预测值 # 2. 计算损失(正向传播的后续步骤) loss = self.loss_fn(outputs, targets) return loss
- 正向传播的具体流程:
- 输入数据
inputs
进入模型,依次经过各层神经网络(如卷积层、全连接层等)。 - 每层根据当前参数对输入进行计算,生成中间特征或最终输出
outputs
。 - 输出值与真实标签
targets
对比,通过损失函数(如交叉熵、均方差)计算损失loss
。
- 输入数据
二、反向传播(Backpropagation)的位置
1. 核心定义
反向传播是从损失函数出发,反向计算梯度并更新模型参数的过程,主要完成“优化参数”的任务。
2. 代码中的体现
反向传播的触发和执行通常由训练框架(如PyTorch Lightning)自动管理,具体位置可能在:
trainer.fit()
内部:框架会自动处理反向传播和参数更新,无需手动编写梯度计算代码。- 隐含的梯度计算流程:
- 当
step_loss
被计算时(正向传播的最后一步),PyTorch会自动构建计算图,记录所有操作的梯度依赖关系。 - 框架在后台调用
loss.backward()
计算梯度(反向传播的核心),并通过优化器(如SGD、Adam)执行optimizer.step()
更新参数。
- 当
3. 反向传播的关键逻辑(框架内部流程):
# 假设框架内部的反向传播逻辑(实际由trainer.fit管理)
def _run_backward(self, loss, optimizer):
# 1. 清空之前的梯度(避免累积)
optimizer.zero_grad()
# 2. 反向传播计算梯度(从损失函数反向推导各参数的梯度)
loss.backward()
# 3. 根据梯度更新参数(反向传播的最终目标)
optimizer.step()
三、补充说明:框架对反向传播的封装
在PyTorch Lightning等高级框架中,反向传播的底层操作(如梯度计算、参数更新)通常被封装在训练器(Trainer
)中,用户无需手动编写。代码中的 trainer.fit()
会自动完成以下步骤:
- 遍历数据加载器(
train_loader
),每次获取一个批次(batch
)。 - 调用
training_step
执行正向传播和损失计算。 - 累积梯度(若启用梯度累积),并在合适时机触发反向传播和参数更新。
四、延伸:手动实现反向传播的示例(非框架场景)
如果不使用框架,手动实现反向传播的代码结构如下,可帮助理解核心逻辑:
# 正向传播
outputs = model(inputs)
loss = loss_fn(outputs, targets)
# 反向传播
optimizer.zero_grad() # 清空梯度
loss.backward() # 计算梯度(反向传播)
optimizer.step() # 更新参数
这些步骤被 trainer.fit()
隐式处理,因此反向传播的具体代码不可见,但逻辑上存在于框架内部。
通过以上分析,可以明确:正向传播的核心在 self.step(batch)
中的模型前向计算,反向传播由 trainer.fit()
框架自动管理,通过损失函数触发梯度计算和参数更新。