欢迎来到雲闪世界。评估指标很多,例如 MSE、MAE、RMSE 等。当我们关心均值或中位数预测时,这些指标非常重要。然而,当我们想要训练模型关注分布中的其他位置时,我们必须使用不同的指标,而这在数据科学博客文章中并不常见。
在本文中,我们将探讨分位数损失(也称为弹球损失),它是分位数回归中的首选指标。
一些入门定义
在解释分位数损失之前,让我们快速浏览几个定义,以确保我们理解一致。
让我们从一个简单的开始。回归类型的算法预测连续变量,例如,它们预测温度、股票价格、最新 iPhone 的需求等。
现在是时候复习一下统计学知识了。α分位数是将给定的一组数字划分为以下值:α × 100% 的数字小于或等于该值,而其余 (1 − α) × 100% 的数字大于或等于该值。
具体来说,第 50 分位数(中位数)将数据分成两部分,一半的值位于其下方,而另一半位于其上方。同样,第 10 分位数表示我们可以找到 10% 的数据的点,而第 90 分位数表示我们可以预期看到 90% 的数据的点。
结合前面两种思想,分位数回归是一种回归分析,用于估计目标变量的条件分位数。因此,它比平均值更全面地展示了变量之间的关系。
让我们看一个例子。假设我们有一个分位数回归模型来预测明天对苹果的需求。我们的模型预测第 90 分位数为 100,这意味着根据该模型,实际需求为 100 或更低的概率为 90%。
分位数损失背后的直觉
在查看损失公式或图之前,让我们先对如何评估分位数(或概率)预测进行一些直观的了解。为了说明这一点,请考虑前面的示例:
如果第 90 分位数预测为 100,则表示实际需求为 100 或更低的概率为 90%。
对于这样的预测,我们预计需求在 90% 的情况下会低于 100。鉴于这种概率陈述,这种预测因低估需求(或任何其他值)而受到的惩罚应该比因高估需求而受到的惩罚更高。这暗示了分位数损失是一种非对称损失函数。
此外,按照这种逻辑,低估的惩罚应该随着分位数的提高而增加。因此,分位数越高,分位数损失对低估的惩罚就越大,对高估的惩罚就越小。
现在让我们考虑一下分位数范围的另一个极端。第 10 分位数预测值为 100,这表明 90% 的时间里,我们预计实际值会高于 100。因此,第 10 分位数的分位数损失函数应该对高估真实值施加比低估更大的惩罚。这将反映出准确捕捉分布中较低值的重要性。
公式
现在我们已经有了一些直觉,让我们看一下分位数损失的公式:
其中 α 是分位数,y 是实际值,y_hat 是预测值,(y — y_hat) 是预测误差。第一种情况(行)代表预测不足,而第二种情况代表预测过度。
现在,为了使其更加直观,让我们检查显示第 10、第 50 和第 90 分位数的分位数损失的图。我们可以使用以下代码片段生成这样的图:
import numpy as np
import matplotlib.pyplot as plt
# 定义 pinball 损失函数
def pinball_loss ( y_true, y_pred, quantile ):
return np.where(y_true >= y_pred, quantile * (y_true - y_pred), (quantile - 1 ) * (y_true - y_pred))
# 生成一系列预测误差
errors = np.linspace(- 10 , 10 , 400 )
y_true = 0
# 分位数
quantiles = [ 0.1 , 0.5 , 0.9 ]
line_styles = [ '-' , '--' , '-.' ]
# 绘图
plt.figure(figsize=( 10 , 6 ))
for q, ls in zip (quantiles, line_styles):
loss = pinball_loss(y_true, errors, q)
plt.plot(errors, loss, linestyle=ls, label= f'Quantile {q* 100 : .0 f} ' )
plt.axhline( 0 , color= 'gray' , linestyle= '--' , linewidth= 0.5 )
plt.axvline( 0 , color= 'gray' , linestyle= '--' , linewidth= 0.5 )
plt.xlabel( '预测误差 (y_true - y_pred)' )
plt.ylabel( 'Pinball Loss' )
plt.title( '不同分位数的 Pinball 损失' )
plt.legend()
plt.grid( True )
plt.show()
图片来自作者
检查该情节可以得出以下结论:
- 橙色虚线表示中位数。如您所见,该线围绕零(完美预测)对称。换句话说,使用中位数会为低估和高估分配相同的权重。这也相当于使用 MAE 损失函数。
- α < 0.5 的值使得过度预测的成本更高。因此,模型倾向于低估目标。
- α > 0.5 的值使得预测不足的代价更高。因此,模型会倾向于高估目标。
- α 距离 0.5 越远,低估或高估的可能性就越大。
- 分位数损失越低,模型表现越好。正如我们已经提到的,损失为 0 代表满分。
旁注:“弹球损失”这个名字来自于损失函数的形状,类似于球在弹球机中弹跳的方式。
让我们看一个具体的例子来了解所有这些是如何结合在一起的。假设我们对第 90 分位数感兴趣。损失函数将如下所示:
因此,对于 α = 0.9,低估将被惩罚 0.9 倍,而高估将被惩罚 0.1 倍。我们可以看到,在这种情况下,低估的惩罚比高估的惩罚严重 9 倍。因此,回归模型将更加关注低估,并且倾向于更频繁地预测更高的值。
平均而言,我们可以预期这种模型在约 90% 的情况下会预测过高,而在其余 10% 的情况下会预测过低。这基本上就是第 90 分位数所代表的意思。
在下图中,我们可以看到分位数损失不对称的具体例子。假设真实值为 100,预测误差为 5(高于和低于),则低估的惩罚比高估的惩罚高出 9 倍。
图片来自作者
对于另一个极端(例如第 10 分位数),可以轻松地重复类似的计算。
何时使用
现在我们知道了什么是分位数/弹球损失,让我们考虑一下它的用例。我们知道它用于评估分位数回归和概率预测,但我们什么时候会使用它呢?以下是几个例子:
- 预测分位数——当我们想要预测分布的某些分位数而不是平均值时。这在金融风险管理或天气预报等领域很有用。
- 考虑损失的不对称性——在某些情况下,高估和低估的成本并不相等。例如,在供应链管理中,低估需求可能会导致销售损失。另一方面,高估可能会导致库存过剩。我们可以使用分位数损失来优化这种不对称情况。
- 预测区间——例如,在时间序列预测中,预测一个范围(例如,第 10 分位数和第 90 分位数之间)可以比点预测提供更多有用的信息。
使用 Midjourney 生成的图像
包起来
在本文中,我们介绍了分位数损失、它是什么以及何时使用它。主要内容如下: