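From a statistical learning perspective, the Huber loss is a loss function used in robust regression; compared with the mean squared error, it is less sensitive to outliers. A variant of it is also sometimes used for classification problems.
The definition of the Huber function is given first: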
![](https://img-blog.csdn.net/20141117213955359?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvdTAxMDkyMjE4Ng==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center)
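For reference, the standard textbook form of the Huber loss can be written in LaTeX as follows. The threshold symbol $\delta > 0$ is an assumption here, chosen to match the `\delta` used in the plotting code later in the post.

```latex
L_\delta(a) =
\begin{cases}
\frac{1}{2} a^2, & |a| \le \delta, \\
\delta \left( |a| - \frac{1}{2}\delta \right), & \text{otherwise.}
\end{cases}
```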
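This function is quadratic in the error for small values of a and linear for large values. The variable a denotes the residual, i.e. the difference between the observed and predicted values, $a = y - f(x)$, so the expression above can be written in the following form: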
![](https://img-blog.csdn.net/20141117214541779?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvdTAxMDkyMjE4Ng==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center)
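As a small numerical check of the residual form, here is a minimal sketch that evaluates the loss element-wise from observed and predicted values. The function name `huber`, the default `delta=1.0`, and the sample arrays are hypothetical and chosen only for illustration.

```python
import numpy as np

def huber(y, y_pred, delta=1.0):
    """Element-wise Huber loss between observed y and predicted y_pred."""
    a = np.asarray(y, dtype=float) - np.asarray(y_pred, dtype=float)  # residuals
    abs_a = np.abs(a)
    quadratic = 0.5 * a**2                   # branch used where |a| <= delta
    linear = delta * (abs_a - 0.5 * delta)   # branch used where |a| > delta
    return np.where(abs_a <= delta, quadratic, linear)

# Hypothetical data: the last observation is an outlier (residual 7)
y_obs = np.array([1.0, 2.0, 3.0, 10.0])
y_hat = np.array([1.5, 2.0, 2.0, 3.0])

losses = huber(y_obs, y_hat, delta=1.0)
squared = 0.5 * (y_obs - y_hat) ** 2
print(losses)
print(squared)
```

For the three small residuals the two losses agree (0.125, 0 and 0.5), while for the outlier the Huber loss is 6.5 against 24.5 for the squared loss, which is the linear-versus-quadratic growth described above.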
Visualization

    import numpy as np
    import matplotlib.pyplot as plt

    def huber_loss(e, d):
        # Quadratic for |e| <= d, linear for |e| > d (on both sides of zero)
        return (np.abs(e) <= d) * e**2 / 2 + (np.abs(e) > d) * d * (np.abs(e) - d / 2)

    plt.figure(figsize=(6, 4.5), facecolor='w', edgecolor='k')
    x = np.arange(-20, 20)
    # Squared loss for comparison; the format string must precede keyword arguments
    plt.plot(x, x**2 / 2, 'g', label='squared loss', lw=2)
    for d in (10, 5, 3, 1):
        plt.plot(x, huber_loss(x, d), label=r'huber loss: $\delta$={}'.format(d), lw=2)
    plt.legend(loc='best', frameon=False)
    plt.xlabel('residual')
    plt.ylabel('loss')
    plt.show()
![](https://img-blog.csdn.net/20161129213901604?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQv/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center)
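To make the outlier-insensitivity claim from the introduction concrete, the sketch below fits a single constant prediction c to a small sample by brute-force grid search, once under the squared loss and once under the Huber loss. The sample, the grid, and the choice `delta=1.0` are all hypothetical, and grid search is used only to keep the example short, not as a recommended optimizer.

```python
import numpy as np

def huber(a, delta):
    """Element-wise Huber loss of residuals a with threshold delta."""
    abs_a = np.abs(a)
    return np.where(abs_a <= delta, 0.5 * a**2, delta * (abs_a - 0.5 * delta))

data = np.array([1.0, 2.0, 3.0, 4.0, 100.0])   # hypothetical sample; 100.0 is an outlier
candidates = np.linspace(0.0, 100.0, 10001)    # candidate constants c, spaced 0.01 apart

# Residual of every data point for every candidate (rows: candidates, columns: data points)
residuals = data[None, :] - candidates[:, None]
squared_total = (0.5 * residuals**2).sum(axis=1)
huber_total = huber(residuals, delta=1.0).sum(axis=1)

best_squared = candidates[np.argmin(squared_total)]
best_huber = candidates[np.argmin(huber_total)]
print(best_squared, best_huber)
```

Under these assumptions the squared-loss minimizer is the sample mean, 22.0, which is dragged toward the outlier, while the Huber minimizer stays at 3.0, in the middle of the remaining points.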