PyTorch 2.X 中 nn.functional.mse_loss() 函数中 reduction 参数的 “mean“ 和 “sum“ 选项的本质区别和影响

🍉 CSDN 叶庭云https://yetingyun.blog.csdn.net/


首先,我们来回顾一下均方误差(Mean Squared Error, MSE)损失函数的基本概念。MSE 是深度学习中,特别是在回归问题中,最常用的损失函数之一。它用于衡量预测值与真实值之间的平均平方差。

基本公式为:

MSE = 1 n ∑ i = 1 n ( y pred , i − y true , i ) 2 \text{MSE} = \frac{1}{n} \sum_{i=1}^n (y_{\text{pred},i} - y_{\text{true},i})^2 MSE=n1i=1n(ypred,iytrue,i)2

在这个公式中:

  • MSE \text{MSE} MSE 代表均方误差
  • n n n 是样本的总数
  • ∑ i = 1 n \sum_{i=1}^n i=1n 表示从 i = 1 i=1 i=1 n n n 的求和
  • y pred , i y_{\text{pred},i} ypred,i 是第 i i i 个样本的预测值
  • y true , i y_{\text{true},i} ytrue,i 是第 i i i 个样本的真实值
  • ( y pred , i − y true , i ) 2 (y_{\text{pred},i} - y_{\text{true},i})^2 (ypred,iytrue,i)2 是预测值与真实值之差的平方

此公式计算了所有样本预测误差平方的平均值,是评估回归模型性能或重建数据效果的常用指标。MSE 值越小,表明模型的预测越准确。

一些额外的说明:

  1. 该公式默认所有样本权重相等。然而,在某些情况下,可能需要根据不同样本的重要性分配不同的权重。

  2. MSE 的单位是目标变量单位的平方。举例而言,若预测对象为美元计价的房价,则 MSE 的单位将是美元的平方。

  3. MSE 对异常值敏感,因为其计算包含平方项,导致大误差被显著放大。因此,在包含显著异常值的数据集上应用 MSE 时需谨慎。

  4. MSE 的平方根被称为 RMSE(均方根误差),它与原始目标变量单位相同,便于理解和解释。

  5. 实际应用中,通常利用训练集训练模型,随后在验证集或测试集上计算 MSE,以评估模型的泛化性能。

PyTorch 中的 nn.functional.mse_loss() 函数。在 PyTorch 中,nn.functional.mse_loss() 函数用于计算均方误差(MSE)损失。其基本用法如下:

import torch.nn.functional as F

loss = F.mse_loss(input, target, reduction='mean')

这里,input 和 target 是需要比较的张量(Tensor),而 reduction=‘mean’ 参数指定了损失值的聚合方式,即计算所有元素损失的平均值。你也可以根据需要调整 reduction 参数的值,比如设置为 ‘sum’ 来计算损失的总和。

reduction 参数的作用。reduction 参数用于控制每个元素损失的汇总方式。它有三个选项:‘none’、‘mean’ 和 ‘sum’。在此,我们主要聚焦于 ‘mean’ 和 ‘sum’ 两个选项。

详解 reduction=“mean”:

当 reduction 设置为 “mean” 时,函数会首先计算所有元素的均方误差(MSE),然后求其平均值。

计算步骤:

  • 计算每个元素的平方误差:(input - target)²
  • 对所有元素的平方误差求和
  • 将总和除以元素的总数(即 batch_size 乘以每个样本的元素数)
  • 数学表达式: loss_mean = Σ(input - target)² / (batch_size * num_elements_per_sample)

示例:假设 batch_size 为 2,每个样本包含 3 个元素,输入和目标值如下:

  • input = [[1, 2, 3], [4, 5, 6]]
  • target = [[2, 2, 2], [4, 4, 4]]
  • 平方误差:[1, 0, 1, 0, 1, 4]
  • 求和:1 + 0 + 1 + 0 + 1 + 4 = 7
  • 平均:7 / (2 * 3) = 7 / 6 ≈ 1.1667

详解 reduction=“sum”:

当 reduction 设置为 “sum” 时,函数会计算所有元素的 MSE,并直接对这些 MSE 求和,不进行平均处理。

计算步骤:

  • 计算每个元素的平方误差:(input - target)²
  • 对所有元素的平方误差求和
  • 数学表达式: loss_sum = Σ(input - target)²

仍使用上述示例:

  • 平方误差:[1, 0, 1, 0, 1, 4]
  • 求和:1 + 0 + 1 + 0 + 1 + 4 = 7

本质区别:

  • mean:提供的是每个元素的平均平方误差,此值不受批量大小或样本维度变化的影响。
  • sum:给出的是所有元素误差的总和,会随着批量大小和样本维度的增加而相应增加。

梯度大小:

  • mean:产生的梯度相对较小,因此更为稳定。
  • sum:可能产生较大的梯度,特别是在处理大批量或高维输入时。

学习率敏感度:

  • mean:对学习率的变化不太敏感。
  • sum:可能需要较小的学习率来避免过调,特别是在大批量情况下。

批量大小的影响:

  • mean:损失值保持稳定,不受批量大小变化的影响。
  • sum:随着批量大小的增加,损失值也会相应增加。

可解释性:

  • mean:由于代表平均误差,因此更容易进行解释。
  • sum:可能难以直观理解,特别是在处理不同大小的数据集时。

使用建议:

  • 对于大多数应用场景,推荐使用 mean,因为它提供了一个标准化的误差度量标准。
  • 如果你特别关注总体误差而非平均误差,可以考虑使用 sum。
  • 在某些特定情况下,如需要损失函数对样本数量敏感时,sum 可能更为合适。

一个简单的代码示例如下:

import torch
import torch.nn.functional as F

input = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
target = torch.tensor([[2., 2., 2.], [4., 4., 4.]])

loss_mean = F.mse_loss(input, target, reduction='mean')
loss_sum = F.mse_loss(input, target, reduction='sum')

print(f"Mean loss: {loss_mean.item()}")
print(f"Sum loss: {loss_sum.item()}")

输出:

Mean loss: 1.1666666666666667
Sum loss: 7.0

这个例子清晰地展示了 “mean” 与 “sum” 之间的区别:其中,“sum” 表示总误差,而 “mean” 则提供了平均误差。

总结:正确理解和应用 reduction 参数对于 MSE 损失函数的使用和解释至关重要。在决定使用 mean 还是 sum 时,需根据具体需求和问题设定来选择。一般而言,mean 提供了一个更为标准化、便于比较的误差度量方式;而 sum 则在特定场景下可能更为适用。无论选择哪种方式,关键在于明确其含义,并在整个训练流程中保持一致性。


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值