文章最前: 我是Octopus,这个名字来源于我的中文名--章鱼;我热爱编程、热爱算法、热爱开源。所有源码在我的个人github ;这博客是记录我学习的点点滴滴,如果您对 Python、Java、AI、算法有兴趣,可以关注我的动态,一起学习,共同进步。
对于回归预测任务来说,并不是所有时候我们都只追求绝对准确的预测,事实上,我们的预测总是不准确的,所以有时需要一个预测区间,而不是寻找绝对的精度,在这种情况下我们需要分位数 回归——我们预测目标的区间估计。
损失函数
幸运的是,强大的力量lightGBM
使得分位数预测成为可能,分位数回归与一般回归的主要区别在于损失函数,称为 pinball 损失或分位数损失。有一个关于弹球损失的很好的解释,它有公式:
其中y
是实际值,z
是预测值,𝛕 是目标分位数。所以第一眼看到损失函数,我们可以看到,除了分位数等于0.5时,损失函数是不对称的
图中绘制了三个不同的分位数,以分位数0.8为例,当误差为正时(z > y
——预测值高于实际值),损失比误差为负时要小。在另一个世界中,错误越高受到的惩罚越少,这是有道理的 对于高分位数预测,损失函数鼓励更高的预测值,反之亦然,对于低分位数预测。
波士顿房价分位数预测:
import lightgbm as lgb
from sklearn.datasets import load_boston
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from pandas import DataFrame
import matplotlib.pyplot as plt
# 准备数据
boston = load_boston()
x, y = boston.data, boston.target
x_df = DataFrame(x, columns= boston.feature_names)
x_train, x_test, y_train, y_test = train_test_split(x_df, y, test_size=0.15)
# 构造模型
# defining parameters
params = {
'task': 'train',
'boosting': 'gbdt',
'metric': 'quantile',
'objective': 'quantile',
'num_leaves': 10,
'learnnig_rage': 0.05,
'metric': {'l2','l1'},
'verbose': -1
}
# 95分位数
upper = lgb.LGBMRegressor(objective = 'quantile', alpha = 0.95)
upper.fit(x_train, y_train)
upper_pred = upper.predict(x_test)
# 5分位数
lower = lgb.LGBMRegressor(objective = 'quantile', alpha = 1 - 0.95)
lower.fit(x_train, y_train)
lower_pred = lower.predict(x_test)
# 50分位数
upper = lgb.LGBMRegressor(objective = 'quantile', alpha = 0.5)
upper.fit(x_train, y_train)
pred = upper.predict(x_test)
# 数据可视化
plt.figure(figsize=(10, 6))
plt.scatter(y_test, lower_pred, color='limegreen', marker='o', label='lower', lw=0.5, alpha=0.5)
plt.scatter(y_test, pred, color='aqua', marker='x', label='pred', alpha=0.7)
plt.scatter(y_test, upper_pred, color='dodgerblue', marker='o', label='upper', lw=0.5, alpha=0.5)
plt.plot(sorted(y_test), sorted(lower_pred), color='limegreen')
plt.plot(sorted(y_test), sorted(pred), color='red')
plt.plot(sorted(y_test), sorted(upper_pred), color='dodgerblue')
plt.legend()
plt.show()
预测结果展示图: