分位数损失(又称 Pinball 损失)简介

         欢迎来到雲闪世界评估指标很多,例如 MSE、MAE、RMSE 等。当我们关心均值或中位数预测时,这些指标非常重要。然而,当我们想要训练模型关注分布中的其他位置时,我们必须使用不同的指标,而这在数据科学博客文章中并不常见。

在本文中,我们将探讨分位数损失(也称为弹球损失),它是分位数回归中的首选指标。

一些入门定义

在解释分位数损失之前,让我们快速浏览几个定义,以确保我们理解一致。

让我们从一个简单的开始。回归类型的算法预测连续变量,例如,它们预测温度、股票价格、最新 iPhone 的需求等。

现在是时候复习一下统计学知识了。α分位数是将给定的一组数字划分为以下值:α × 100% 的数字小于或等于该值,而其余 (1 − α) × 100% 的数字大于或等于该值。

具体来说,第 50 分位数(中位数)将数据分成两部分,一半的值位于其下方,而另一半位于其上方。同样,第 10 分位数表示我们可以找到 10% 的数据的点,而第 90 分位数表示我们可以预期看到 90% 的数据的点。

结合前面两种思想,分位数回归是一种回归分析,用于估计目标变量的条件分位数。因此,它比平均值更全面地展示了变量之间的关系。

让我们看一个例子。假设我们有一个分位数回归模型来预测明天对苹果的需求。我们的模型预测第 90 分位数为 100,这意味着根据该模型,实际需求为 100 或更低的概率为 90%。

分位数损失背后的直觉

在查看损失公式或图之前,让我们先对如何评估分位数(或概率)预测进行一些直观的了解。为了说明这一点,请考虑前面的示例:

如果第 90 分位数预测为 100,则表示实际需求为 100 或更低的概率为 90%。

对于这样的预测,我们预计需求在 90% 的情况下会低于 100。鉴于这种概率陈述,这种预测因低估需求(或任何其他值)而受到的惩罚应该比因高估需求而受到的惩罚更高。这暗示了分位数损失是一种非对称损失函数。

此外,按照这种逻辑,低估的惩罚应该随着分位数的提高而增加。因此,分位数越高,分位数损失对低估的惩罚就越大,对高估的惩罚就越小。

现在让我们考虑一下分位数范围的另一个极端。第 10 分位数预测值为 100,这表明 90% 的时间里,我们预计实际值会高于 100。因此,第 10 分位数的分位数损失函数应该对高估真实值施加比低估更大的惩罚。这将反映出准确捕捉分布中较低值的重要性。

公式

现在我们已经有了一些直觉,让我们看一下分位数损失的公式:

其中 α 是分位数,y 是实际值,y_hat 是预测值,(y — y_hat) 是预测误差。第一种情况(行)代表预测不足,而第二种情况代表预测过度。

现在,为了使其更加直观,让我们检查显示第 10、第 50 和第 90 分位数的分位数损失的图。我们可以使用以下代码片段生成这样的图:

import numpy as np 
import matplotlib.pyplot as plt 

# 定义 pinball 损失函数
def  pinball_loss ( y_true, y_pred, quantile ): 
    return np.where(y_true >= y_pred, quantile * (y_true - y_pred), (quantile - 1 ) * (y_true - y_pred)) 

# 生成一系列预测误差
errors = np.linspace(- 10 , 10 , 400 )   
y_true = 0   

# 分位数
quantiles = [ 0.1 , 0.5 , 0.9 ] 
line_styles = [ '-' , '--' , '-.' ]   

# 绘图
plt.figure(figsize=( 10 , 6 )) 
for q, ls in  zip (quantiles, line_styles): 
    loss = pinball_loss(y_true, errors, q) 
    plt.plot(errors, loss, linestyle=ls, label= f'Quantile {q* 100 : .0 f} ' ) 
    
plt.axhline( 0 , color= 'gray' , linestyle= '--' , linewidth= 0.5 ) 
plt.axvline( 0 , color= 'gray' , linestyle= '--' , linewidth= 0.5 ) 
plt.xlabel( '预测误差 (y_true - y_pred)' ) 
plt.ylabel( 'Pinball Loss' ) 
plt.title( '不同分位数的 Pinball 损失' ) 
plt.legend() 
plt.grid( True ) 
plt.show()

图片来自作者

检查该情节可以得出以下结论:

  • 橙色虚线表示中位数。如您所见,该线围绕零(完美预测)对称。换句话说,使用中位数会为低估和高估分配相同的权重。这也相当于使用 MAE 损失函数。
  • α < 0.5 的值使得过度预测的成本更高。因此,模型倾向于低估目标。
  • α > 0.5 的值使得预测不足的代价更高。因此,模型会倾向于高估目标。
  • α 距离 0.5 越远,低估或高估的可能性就越大。
  • 分位数损失越低,模型表现越好。正如我们已经提到的,损失为 0 代表满分。

旁注:“弹球损失”这个名字来自于损失函数的形状,类似于球在弹球机中弹跳的方式。

让我们看一个具体的例子来了解所有这些是如何结合在一起的。假设我们对第 90 分位数感兴趣。损失函数将如下所示:

因此,对于 α = 0.9,低估将被惩罚 0.9 倍,而高估将被惩罚 0.1 倍。我们可以看到,在这种情况下,低估的惩罚比高估的惩罚严重 9 倍。因此,回归模型将更加关注低估,并且倾向于更频繁地预测更高的值。

平均而言,我们可以预期这种模型在约 90% 的情况下会预测过高,而在其余 10% 的情况下会预测过低。这基本上就是第 90 分位数所代表的意思。

在下图中,我们可以看到分位数损失不对称的具体例子。假设真实值为 100,预测误差为 5(高于和低于),则低估的惩罚比高估的惩罚高出 9 倍。

图片来自作者

对于另一个极端(例如第 10 分位数),可以轻松地重复类似的计算。

何时使用

现在我们知道了什么是分位数/弹球损失,让我们考虑一下它的用例。我们知道它用于评估分位数回归和概率预测,但我们什么时候会使用它呢?以下是几个例子:

  • 预测分位数——当我们想要预测分布的某些分位数而不是平均值时。这在金融风险管理或天气预报等领域很有用。
  • 考虑损失的不对称性——在某些情况下,高估和低估的成本并不相等。例如,在供应链管理中,低估需求可能会导致销售损失。另一方面,高估可能会导致库存过剩。我们可以使用分位数损失来优化这种不对称情况。
  • 预测区间——例如,在时间序列预测中,预测一个范围(例如,第 10 分位数和第 90 分位数之间)可以比点预测提供更多有用的信息。

使用 Midjourney 生成的图像

包起来

在本文中,我们介绍了分位数损失、它是什么以及何时使用它。主要内容如下:

  • 分位数损失是一种不对称的、成本敏感的损失函数,用于训练预测目标变量分布的特定分位数的模型。
  • 对于中位数,它相当于MAE。
  • α 距离 0.5 越远,低估或高估的可能性就越大。
  • 感谢关注雲闪世界。(亚马逊aws谷歌GCP服务协助解决云计算及产业相关解决方案)

     订阅频道(https://t.me/awsgoogvps_Host)
     TG交流群(t.me/awsgoogvpsHost)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值