一、前言
这篇文章已经讲得特别好了《一分钟理解softmax函数(超简单)》,为了加深理解,我按照自己的理解进行引用修改,并配上代码实现。
前面我们把sigmoid函数引入《机器学习 逻辑回归(1)二分类》中,用于解决是或否的二分类问题。但存在局限性,如果预测结果有多种类别,那怎么处理呢?
首先,我们很容易想到,如果计算结果是各种类别的概率,那就好了。比如说,总共有三个类别A、B、C
,我们通过函数计算,类别A的概率为0.1
、B的概率为0.6
、C的概率为0.3
,那么我们就大概率认定预测结果为B。
这就对我们的函数提出了要求:
1)预测的概率为非负数;
2)各种预测结果概率之和等于1
这就是softmax函数
要做的事。
二、将结果转化为非负数
我们来看下指数函数 y = e x y=e^x y=ex 的图像:
import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(suppress=True) #numpy不使用科学计数法
x = np.arange(-5, 5, 0.1) #起点,终点,间距
y = np.exp(x)
plt.plot(x, y)
plt.show()
值域为
(
0
,
+
∞
)
(0,+\infty)
(0,+∞),不管是输入正数还是负数,其结果都是正数,这不正是我们需要的么!
s o f t m a x 第一步就是将模型的预测结果转化到指数函数上,这样保证了概率的非负性。 \color{red}{softmax第一步就是将模型的预测结果转化到指数函数上,这样保证了概率的非负性。} softmax第一步就是将模型的预测结果转化到指数函数上,这样保证了概率的非负性。
三、预测结果概率之和等于1
办法是 将转化后的结果除以所有转化后结果之和,可以理解为转化后结果占总数的百分比。 \color{red}{将转化后的结果除以所有转化后结果之和,可以理解为转化后结果占总数的百分比。} 将转化后的结果除以所有转化后结果之和,可以理解为转化后结果占总数的百分比。
四、代码实现
例如:预测结果值分别是
x
1
=
−
3
、
x
2
=
1.5
、
x
3
=
2.7
x_1=-3、x_2=1.5、x_3=2.7
x1=−3、x2=1.5、x3=2.7,我们用softmax函数
将结果转化成概率:
1)将预测结果转化为非负数
y
1
=
e
x
1
=
e
−
3
=
0.05
y_1 = e^{x_1} = e^{-3} = 0.05
y1=ex1=e−3=0.05
y
2
=
e
x
2
=
e
1.5
=
4.48
y_2 = e^{x_2} = e^{1.5} = 4.48
y2=ex2=e1.5=4.48
y
3
=
e
x
3
=
e
2.7
=
14.88
y_3 = e^{x_3} = e^{2.7} = 14.88
y3=ex3=e2.7=14.88
代码:
y=np.exp([-3,1.5,2.7])
print(y)
运行结果:
[ 0.04978707 4.48168907 14.87973172]
2)各种预测结果概率之和等于1
z
1
=
y
1
/
(
y
1
+
y
2
+
y
3
)
=
0.05
/
(
0.05
+
4.48
+
14.88
)
=
0.0026
z_1 = y_1/(y_1+y_2+y_3) = 0.05/(0.05+4.48+14.88) = 0.0026
z1=y1/(y1+y2+y3)=0.05/(0.05+4.48+14.88)=0.0026
z
2
=
y
2
/
(
y
1
+
y
2
+
y
3
)
=
4.48
/
(
0.05
+
4.48
+
14.88
)
=
0.2308
z_2 = y_2/(y_1+y_2+y_3) = 4.48/(0.05+4.48+14.88) = 0.2308
z2=y2/(y1+y2+y3)=4.48/(0.05+4.48+14.88)=0.2308
z
3
=
y
3
/
(
y
1
+
y
2
+
y
3
)
=
14.88
/
(
0.05
+
4.48
+
14.88
)
=
0.7666
z_3 = y_3/(y_1+y_2+y_3) = 14.88/(0.05+4.48+14.88) = 0.7666
z3=y3/(y1+y2+y3)=14.88/(0.05+4.48+14.88)=0.7666
代码:
z=y/y.sum()
print(z)
运行结果:
[0.00256486 0.23088151 0.76655362]
softmax
函数代码实现:
def softmax(x):
ex=np.exp(x)
return ex/ex.sum()
测试:
softmax([-3,1.5,2.7])
运行结果:
array([0.00256486, 0.23088151, 0.76655362])