Logistic Regression
对于一个二分类问题而言
y ^ = s i g m o i d ( w T x + b ) = σ ( w T x + b ) ) \hat{y}=sigmoid(w^Tx+b)=\sigma(w^Tx+b)) y^=sigmoid(wTx+b)=σ(wTx+b))
表示样本 x x x label为1的概率,取值范围为 [ 0 , 1 ] [0,1] [0,1]。
其中,
s
i
g
m
o
i
d
(
z
)
=
σ
(
z
)
=
1
(
1
+
e
−
z
)
sigmoid(z)=\sigma(z)=\frac{1}{(1+e^{-z})}
sigmoid(z)=σ(z)=(1+e−z)1
Note:
σ ( z ) ′ = σ ( z ) ( 1 − σ ( z ) ) {\sigma}(z)'=\sigma(z)(1-\sigma(z)) σ(z)′=σ(z)(1−σ(z))
则通过上述模型可以得出
P ( y = 1 ∣ x ) = 1 1 + e − x P(y=1|x)=\frac{1}{1+e^{-x}} P(y=1∣x)=1+e−x1
P ( y = 0 ∣ x ) = e − x 1 + e − x = 1 − P ( y = 1 ∣ x ) P(y=0|x)=\frac{e^{-x}}{1+e^{-x}}=1-P(y=1|x) P(y=0∣x)=1+e−xe−x=1−P(y=1∣x)
另一个角度
一个事件发生的概率p与不发生的概率的比值称为该事件的几率
s
m
a
l
l
p
1
−
p
(
o
d
d
s
)
small \frac{p}{1-p} (odds)
small1−pp(odds)。
逻辑斯蒂回归模型即 Y = 1 Y=1 Y=1的对数几率是输入 x x x 的线性函数(统计学习方法)。
Loss Function
一般经验来说,使用均方误差(mean squared error)来衡量Loss Function: L ( y , y ^ ) = 1 2 ( y − y ^ ) 2 L(y,\hat{y})=\frac{1}{2}(y-\hat{y})^2 L(y,y^)=21(y−y^)2 .
但是,对于logistic regression 来说,一般不适用均方误差来作为Loss Function,这是因为:
上面的均方误差损失函数一般是非凸函数(non-convex),其在使用梯度下降算法的时候,容易得到局部最优解,而不是全局最优解。因此要选择凸函数(二阶导大于等于0)。
使用MSE的另一个缺点就是其偏导值在输出概率值接近0或者接近1的时候非常小,这可能会造成模型刚开始训练时,偏导值几乎消失。
这里选择的损失函数交叉熵(信息论)损失函数:
L ( y ^ , y ) = − ( y l o g ( y ^ ) + ( 1 − y ) l o g ( 1 − y ^ ) ) L(\hat{y},y)=-(ylog(\hat{y})+(1-y)log(1-\hat{y})) L(y^,y)=−(ylog(y^)+(1−y)log(1−y^))
网上找了很多博客也没有推导交叉熵损失函数的凸性的博文,所以下面我来推导一下:
这里为了推导方便,假设 x ∈ R 1 \small x\in R^{1} x∈R1
首先我们推导为什么MSE不是凸函数
L
(
y
,
y
^
)
=
1
2
(
y
−
y
^
)
2
L(y,\hat{y})=\frac{1}{2}(y-\hat{y})^2
L(y,y^)=21(y−y^)2
∂ L ( w , b ) ∂ w = ∂ L ( w , b ) ∂ y ^ ∂ y ^ ∂ w = ( y ^ − y ) y ^ ( 1 − y ^ ) x = ( − y ^ 3 + ( 1 + y ) y ^ 2 − y y ^ ) x \frac{\partial L(w,b)}{\partial w}=\frac{\partial L(w,b)}{\partial \hat{y}}\frac{\partial \hat{y}}{\partial w}=(\hat{y}-y)\hat{y}(1-\hat{y})x=(-\hat{y}^3+(1+y)\hat{y}^2-y\hat{y})x ∂w∂L(w,b)=∂y^∂L(w,b)∂w∂y^=(y^−y)y^(1−y^)x=(−y^3+(1+y)y^2−yy^)x
∂ 2 L ( w , b ) ∂ w 2 = ∂ ∂ w ( ∂ L ( w , b ) ∂ w ) = ( − 3 y ^ 2 + 2 ( 1 + y ) y ^ − y ) y ^ ( 1 − y ^ ) x 2 \small \frac{\partial ^{2}L(w,b)}{\partial w^{2}}=\frac{\partial}{\partial w}(\frac{\partial L(w,b)}{\partial w})=(-3\hat{y}^2+2(1+y)\hat{y}-y)\hat{y}(1-\hat{y})x^2 ∂w2∂2L(w,b)=∂w∂(∂w∂L(w,b))=(−3y^2+2(1+y)y^−y)y^(1−y^)x2不能保证大于等于0
同理对于 b \small b b有,
∂ 2 L ( w , b ) ∂ b 2 = ∂ ∂ b ( ∂ L ( w , b ) ∂ b ) = ( − 3 y ^ 2 + 2 ( 1 + y ) y ^ − y ) y ^ ( 1 − y ^ ) \small \frac{\partial ^{2}L(w,b)}{\partial b^{2}}=\frac{\partial}{\partial b}(\frac{\partial L(w,b)}{\partial b})=(-3\hat{y}^2+2(1+y)\hat{y}-y)\hat{y}(1-\hat{y}) ∂b2∂2L(w,b)=∂b∂(∂b∂L(w,b))=(−3y^2+2(1+y)y^−y)y^(1−y^)不能保证大于等于0
证毕。
再推导为什么交叉熵损失函数是凸函数:
L ( y ^ , y ) = − ( y l o g ( y ^ ) + ( 1 − y ) l o g ( 1 − y ^ ) ) L(\hat{y},y)=-(ylog(\hat{y})+(1-y)log(1-\hat{y})) L(y^,y)=−(ylog(y^)+(1−y)log(1−y^))
∂ L ( w , b ) ∂ w = ∂ L ( w , b ) ∂ y ^ ∂ y ^ ∂ w = − ( y y ^ − 1 − y 1 − y ^ ) y ^ ( 1 − y ^ ) x = ( y ^ − y ) x \frac{\partial L(w,b)}{\partial w}=\frac{\partial L(w,b)}{\partial \hat{y}}\frac{\partial \hat{y}}{\partial w}=-(\frac{y}{\hat{y}}-\frac{1-y}{1-\hat{y}})\hat{y}(1-\hat{y})x=(\hat{y}-y)x ∂w∂L(w,b)=∂y^∂L(w,b)∂w∂y^=−(y^y−1−y^1−y)y^(1−y^)x=(y^−y)x
∂ 2 L ( w , b ) ∂ w 2 = ∂ ∂ w ( ∂ L ( w , b ) ∂ w ) = x ∂ y ^ ∂ w = y ^ ( 1 − y ^ ) x 2 ≥ 0 \small \frac{\partial ^{2}L(w,b)}{\partial w^{2}}=\frac{\partial}{\partial w}(\frac{\partial L(w,b)}{\partial w})=x\frac{\partial \hat{y}}{\partial w}=\hat{y}(1-\hat{y})x^2\geq 0 ∂w2∂2L(w,b)=∂w∂(∂w∂L(w,b))=x∂w∂y^=y^(1−y^)x2≥0
对于 b \small b b同理有,
∂ L ( w , b ) ∂ b = ∂ L ( w , b ) ∂ y ^ ∂ y ^ ∂ b = − ( y y ^ − 1 − y 1 − y ^ ) y ^ ( 1 − y ^ ) = y ^ − y \frac{\partial L(w,b)}{\partial b}=\frac{\partial L(w,b)}{\partial \hat{y}}\frac{\partial \hat{y}}{\partial b}=-(\frac{y}{\hat{y}}-\frac{1-y}{1-\hat{y}})\hat{y}(1-\hat{y})=\hat{y}-y ∂b∂L(w,b)=∂y^∂L(w,b)∂b∂y^=−(y^y−1−y^1−y)y^(1−y^)=y^−y
∂ 2 L ( w , b ) ∂ b 2 = ∂ ∂ b ( ∂ L ( w , b ) ∂ b ) = ∂ y ^ ∂ b = y ^ ( 1 − y ^ ) ≥ 0 \small \frac{\partial ^{2}L(w,b)}{\partial b^{2}}=\frac{\partial}{\partial b}(\frac{\partial L(w,b)}{\partial b})=\frac{\partial \hat{y}}{\partial b}=\hat{y}(1-\hat{y})\geq 0 ∂b2∂2L(w,b)=∂b∂(∂b∂L(w,b))=∂b∂y^=y^(1−y^)≥0
证毕。