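Linear Discriminant Analysis (Fisher Discriminant Analysis)
Linear discriminant analysis (LDA) is a classic linear learning method.
The idea behind LDA is simple: given a training set, project the samples onto a line so that the projections of same-class samples are as close together as possible and the projections of different-class samples are as far apart as possible. To classify a new sample, project it onto the same line and assign its class according to where the projection falls.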
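In the accompanying figure, "+" and "-" denote positive and negative examples, the ellipses show the outline of each data cluster, the dashed lines indicate projection, and the solid red circle and solid triangle mark the projected centers of the two classes.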
Given a data set $D = \{(x_i, y_i)\}_{i=1}^{m}$ with $y_i \in \{0, 1\}$, let $X_i$, $\mu_i$, and $\Sigma_i$ denote the set of examples, the mean vector, and the covariance matrix of class $i \in \{0, 1\}$. If the data are projected onto a line with direction $\omega$, the centers of the two classes project to $\omega^{T}\mu_0$ and $\omega^{T}\mu_1$; if all sample points are projected onto the line, the covariances of the two classes become $\omega^{T}\Sigma_0\omega$ and $\omega^{T}\Sigma_1\omega$. Because the line is a one-dimensional space, these four quantities are all real numbers.
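The projection formulas can be checked numerically. The sketch below assumes two small synthetic classes and an arbitrary direction `w` (the data, the seed, and the helper name `projected_stats` are made up for illustration). It confirms that the mean and variance of the projected points equal $\omega^{T}\mu_i$ and $\omega^{T}\Sigma_i\omega$.

```python
import numpy as np

rng = np.random.default_rng(0)
# Two hypothetical 2-D classes (synthetic data, for illustration only)
X0 = rng.normal(loc=[0.0, 0.0], scale=1.0, size=(50, 2))
X1 = rng.normal(loc=[3.0, 2.0], scale=1.0, size=(50, 2))

w = np.array([1.0, 0.5])  # an arbitrary projection direction

def projected_stats(X, w):
    """Return (w^T mu, w^T Sigma w) for the samples in X."""
    mu = X.mean(axis=0)
    Sigma = np.cov(X, rowvar=False, bias=True)  # covariance normalised by N
    return w @ mu, w @ Sigma @ w

for X in (X0, X1):
    z = X @ w  # 1-D projection of every sample onto w
    m_proj, v_proj = projected_stats(X, w)
    # Both projected quantities are plain real numbers
    assert np.isclose(z.mean(), m_proj)
    assert np.isclose(z.var(), v_proj)
```

Here `bias=True` makes `np.cov` divide by the number of samples, which matches the default of `ndarray.var`.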
Within-class scatter matrix:
$$S_{\omega} = \Sigma_0 + \Sigma_1 = \sum_{x \in X_0}(x-\mu_0)(x-\mu_0)^{T} + \sum_{x \in X_1}(x-\mu_1)(x-\mu_1)^{T}$$
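As a minimal sketch of the within-class scatter matrix, the sum of outer products can be computed in NumPy as `centered.T @ centered` for each class. The function name `within_class_scatter` and the six toy points are assumptions made for this example.

```python
import numpy as np

def within_class_scatter(X0, X1):
    """Sum over both classes of (x - mu_i)(x - mu_i)^T."""
    dim = X0.shape[1]
    S = np.zeros((dim, dim))
    for X in (X0, X1):
        centered = X - X.mean(axis=0)
        # centered.T @ centered equals the sum of per-sample outer products
        S += centered.T @ centered
    return S

# Toy data, three points per class (illustrative values only)
X0 = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 3.0]])
X1 = np.array([[6.0, 5.0], [7.0, 8.0], [8.0, 7.0]])
S_w = within_class_scatter(X0, X1)
print(S_w)
```

The result is symmetric by construction, and it is invertible whenever the centered samples span the whole feature space.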
Between-class scatter matrix:
$$S_b = (\mu_0-\mu_1)(\mu_0-\mu_1)^{T}$$
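A short sketch of the between-class scatter matrix follows, with made-up mean vectors. It also illustrates a property that follows directly from the definition: $S_b$ is an outer product of one vector with itself, so it has rank 1 and maps any vector onto the direction $\mu_0-\mu_1$.

```python
import numpy as np

# Hypothetical class means (illustrative values only)
mu0 = np.array([2.0, 1.0])
mu1 = np.array([5.0, 3.0])

diff = mu0 - mu1
S_b = np.outer(diff, diff)           # (mu0 - mu1)(mu0 - mu1)^T

w = np.array([0.3, -1.2])            # an arbitrary direction
v = S_b @ w

# S_b w = ((mu0 - mu1)^T w) (mu0 - mu1): a scalar multiple of diff
assert np.allclose(v, (diff @ w) * diff)
print(np.linalg.matrix_rank(S_b))    # prints 1
```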
Objective to maximize:
$$J = \frac{\omega^{T}S_b\omega}{\omega^{T}S_{\omega}\omega}$$
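To make the objective concrete, the sketch below evaluates $J$ for a few directions. The matrices are hypothetical values chosen for illustration, and `fisher_criterion` is a name introduced here. Rescaling a direction leaves $J$ unchanged, while changing the direction does not.

```python
import numpy as np

def fisher_criterion(w, S_b, S_w):
    """Generalized Rayleigh quotient J = (w^T S_b w) / (w^T S_w w)."""
    return (w @ S_b @ w) / (w @ S_w @ w)

# Hypothetical 2-D scatter matrices (illustrative values only)
S_w = np.array([[4.0, 3.0], [3.0, 6.0]])   # symmetric positive definite
diff = np.array([-5.0, -4.0])              # stands in for mu_0 - mu_1
S_b = np.outer(diff, diff)

w_a = np.array([1.0, 0.0])
w_b = np.array([0.0, 1.0])

J_a = fisher_criterion(w_a, S_b, S_w)
J_a_scaled = fisher_criterion(-7.5 * w_a, S_b, S_w)  # same line, different length
J_b = fisher_criterion(w_b, S_b, S_w)                # a different direction

print(J_a, J_a_scaled, J_b)
```

Because only the direction of the vector affects $J$, its length can be fixed by any convenient normalization.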
Solving for $\omega$:
Both the numerator and the denominator of $J$ are quadratic in $\omega$, so $J$ depends only on the direction of $\omega$ and not on its length. We can therefore fix $\omega^{T}S_{\omega}\omega = 1$, and maximizing $J$ becomes equivalent to
$$\min_{\omega}\ -\omega^{T}S_b\omega \quad \text{s.t.} \quad \omega^{T}S_{\omega}\omega = 1$$
By the method of Lagrange multipliers, this is equivalent to
$$S_b\omega = \lambda S_{\omega}\omega$$
where $\lambda$ is the Lagrange multiplier.
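For completeness, the Lagrangian step can be written out as follows. This is a short derivation that uses the symmetry of $S_b$ and $S_{\omega}$; the symbol $L$ for the Lagrangian is introduced here and does not appear elsewhere in the text.

```latex
\begin{aligned}
L(\omega,\lambda) &= -\omega^{T}S_b\omega + \lambda\left(\omega^{T}S_{\omega}\omega - 1\right),\\
\frac{\partial L}{\partial \omega} &= -2S_b\omega + 2\lambda S_{\omega}\omega = 0
\;\Longrightarrow\; S_b\omega = \lambda S_{\omega}\omega .
\end{aligned}
```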
Since $S_b\omega = (\mu_0-\mu_1)(\mu_0-\mu_1)^{T}\omega$ and $(\mu_0-\mu_1)^{T}\omega$ is a scalar, $S_b\omega$ always points along $\mu_0-\mu_1$. Only the direction of $\omega$ matters, so we may set
$$S_b\omega = \lambda(\mu_0-\mu_1)$$
which gives
$$\omega = S_{\omega}^{-1}(\mu_0-\mu_1)$$
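The closed-form solution can be sketched in a few lines of NumPy. The toy data and the helper names `fisher_direction` and `predict` are assumptions for illustration. The midpoint threshold in `predict` is one common choice and is not derived in the text above. The code solves the linear system with `np.linalg.solve` instead of forming the inverse explicitly, then checks that the result satisfies $S_b\omega = \lambda S_{\omega}\omega$.

```python
import numpy as np

def fisher_direction(X0, X1):
    """Return (w, S_w, mu0, mu1) with w = S_w^{-1} (mu0 - mu1)."""
    mu0, mu1 = X0.mean(axis=0), X1.mean(axis=0)
    C0, C1 = X0 - mu0, X1 - mu1
    S_w = C0.T @ C0 + C1.T @ C1
    # Solve S_w w = mu0 - mu1 rather than inverting S_w
    w = np.linalg.solve(S_w, mu0 - mu1)
    return w, S_w, mu0, mu1

# Toy data (illustrative values only)
X0 = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 3.0]])
X1 = np.array([[6.0, 5.0], [7.0, 8.0], [8.0, 7.0]])
w, S_w, mu0, mu1 = fisher_direction(X0, X1)

# Check the generalized eigenvalue equation S_b w = lambda S_w w
S_b = np.outer(mu0 - mu1, mu0 - mu1)
lam = (mu0 - mu1) @ w
assert np.allclose(S_b @ w, lam * (S_w @ w))

def predict(x):
    """Assign class 0 if x projects on the mu0 side of the midpoint, else 1."""
    threshold = w @ (mu0 + mu1) / 2
    return 0 if w @ x > threshold else 1

print(predict(np.array([2.0, 2.5])), predict(np.array([7.0, 7.0])))
```

With this choice of $\omega$, the multiplier $\lambda$ equals $(\mu_0-\mu_1)^{T}S_{\omega}^{-1}(\mu_0-\mu_1)$, which is also the maximal value of $J$.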
Python implementation
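import pandas as pd
import numpy as np

# Load the data: 150 rows, 50 per class, with the four features in columns 0-3
path = r'Iris.csv'
df = pd.read_csv(path, header=0)
# df.values is an object array when the CSV has a string label column, so cast to float
Iris1 = df.values[0:50, 0:4].astype(float)
Iris2 = df.values[50:100, 0:4].astype(float)
Iris3 = df.values[100:150, 0:4].astype(float)

# Class mean vectors (computed over all 50 samples of each class)
m1 = np.mean(Iris1, axis=0)
m2 = np.mean(Iris2, axis=0)
m3 = np.mean(Iris3, axis=0)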
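
# Within-class scatter matrix of each class,
# accumulated over the first 30 (training) samples
s1 = np.zeros((4, 4))
s2 = np.zeros((4, 4))
s3 = np.zeros((4, 4))
for i in range(30):
    a = Iris1[i, :] - m1
    a = np.array([a])           # row vector, shape (1, 4)
    b = a.T                     # column vector, shape (4, 1)
    s1 = s1 + np.dot(b, a)      # outer product (x - m1)(x - m1)^T
for i in range(30):
    c = Iris2[i, :] - m2
    c = np.array([c])
    d = c.T
    s2 = s2 + np.dot(d, c)
for i in range(30):
    a = Iris3[i, :] - m3
    a = np.array([a])
    b = a.T
    s3 = s3 + np.dot(b, a)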
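
# Pooled within-class scatter matrix for each pair of classes
sw12 = s1 + s2
sw13 = s1 + s3
sw23 = s2 + s3
# Make sure the dtype is float so that np.linalg.inv accepts the matrices
sw12 = np.array(sw12, dtype='float')
sw13 = np.array(sw13, dtype='float')
sw23 = np.array(sw23, dtype='float')

# Differences of class means as column vectors, shape (4, 1)
a = np.array([m1 - m2]).T
b = np.array([m1 - m3]).T
c = np.array([m2 - m3]).T

# Projection directions w = Sw^{-1} (m_i - m_j), stored as row vectors of shape (1, 4)
w12 = (np.dot(np.linalg.inv(sw12), a)).T
w13 = (np.dot(np.linalg.inv(sw13), b)).T
w23 = (np.dot(np.linalg.inv(sw23), c)).T

# Thresholds T: the discriminant g(x) = w x + T is zero at the midpoint of the two class means
T12 = -0.5 * (np.dot(np.dot((m1 + m2), np.linalg.inv(sw12)), a))
T13 = -0.5 * (np.dot(np.dot((m1 + m3), np.linalg.inv(sw13)), b))
T23 = -0.5 * (np.dot(np.dot((m2 + m3), np.linalg.inv(sw23)), c))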
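
# Classify the held-out samples (indices 30-49 of each class) by pairwise voting
kind1 = 0        # number of class-1 test samples classified correctly
kind2 = 0
kind3 = 0
newiris1 = []    # samples assigned to class 1
newiris2 = []
newiris3 = []
# range(30, 50) covers all 20 held-out samples per class, matching the divisor of 60 below
for i in range(30, 50):
    x = Iris1[i, :]
    x = np.array([x])
    g12 = np.dot(w12, x.T) + T12
    g13 = np.dot(w13, x.T) + T13
    g23 = np.dot(w23, x.T) + T23
    if g12 > 0 and g13 > 0:
        newiris1.extend(x)
        kind1 = kind1 + 1
    elif g12 < 0 and g23 > 0:
        newiris2.extend(x)
    elif g13 < 0 and g23 < 0:
        newiris3.extend(x)
for i in range(30, 50):
    x = Iris2[i, :]
    x = np.array([x])
    g12 = np.dot(w12, x.T) + T12
    g13 = np.dot(w13, x.T) + T13
    g23 = np.dot(w23, x.T) + T23
    if g12 > 0 and g13 > 0:
        newiris1.extend(x)
    elif g12 < 0 and g23 > 0:
        newiris2.extend(x)
        kind2 = kind2 + 1
    elif g13 < 0 and g23 < 0:
        newiris3.extend(x)
for i in range(30, 50):
    x = Iris3[i, :]
    x = np.array([x])
    g12 = np.dot(w12, x.T) + T12
    g13 = np.dot(w13, x.T) + T13
    g23 = np.dot(w23, x.T) + T23
    if g12 > 0 and g13 > 0:
        newiris1.extend(x)
    elif g12 < 0 and g23 > 0:
        newiris2.extend(x)
    elif g13 < 0 and g23 < 0:
        newiris3.extend(x)
        kind3 = kind3 + 1

# Fraction of correct classifications over 3 classes x 20 test samples
correct = (kind1 + kind2 + kind3) / 60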

print("Within-class scatter matrix S1:", s1, '\n')
print("Within-class scatter matrix S2:", s2, '\n')
print("Within-class scatter matrix S3:", s3, '\n')
print('-----------------------------------------------------------------------------------------------')
print("Pooled within-class scatter matrix Sw12:", sw12, '\n')
print("Pooled within-class scatter matrix Sw13:", sw13, '\n')
print("Pooled within-class scatter matrix Sw23:", sw23, '\n')
print('-----------------------------------------------------------------------------------------------')
print('Overall accuracy:', correct * 100, '%')
Within-class scatter matrix S1: [[4.084080000000003 2.9814400000000005 0.5409999999999995 0.4941599999999999]
 [2.9814400000000005 3.6879200000000028 -0.025000000000000428 0.5628800000000002]
 [0.5409999999999995 -0.025000000000000428 1.0829999999999995 0.19]
 [0.4941599999999999 0.5628800000000002 0.19 0.30832000000000004]]
Within-class scatter matrix S2: [[8.316120000000005 2.7365199999999987 5.568960000000003 1.7302799999999998]
 [2.7365199999999987 3.09192 2.49916 1.3588799999999999]
 [5.568960000000003 2.49916 6.258680000000002 2.2232399999999997]
 [1.7302799999999998 1.3588799999999999 2.2232399999999997 1.3543200000000004]]
Within-class scatter matrix S3: [[14.328471470220745 3.1402832153269435 11.94600583090379 1.3147563515201988]
 [3.1402832153269435 3.198721366097457 2.239650145772593 1.2317617659308615]
 [11.94600583090379 2.239650145772593 11.600816326530618 1.4958892128279884]
 [1.3147563515201988 1.2317617659308615 1.4958892128279884 1.6810578925447726]]
-----------------------------------------------------------------------------------------------
Pooled within-class scatter matrix Sw12: [[12.4002 5.71796 6.10996 2.22444]
 [ 5.71796 6.77984 2.47416 1.92176]
 [ 6.10996 2.47416 7.34168 2.41324]
 [ 2.22444 1.92176 2.41324 1.66264]]
Pooled within-class scatter matrix Sw13: [[18.41255147 6.12172322 12.48700583 1.80891635]
 [ 6.12172322 6.88664137 2.21465015 1.79464177]
 [12.48700583 2.21465015 12.68381633 1.68588921]
 [ 1.80891635 1.79464177 1.68588921 1.98937789]]
Pooled within-class scatter matrix Sw23: [[22.64459147 5.87680322 17.51496583 3.04503635]
 [ 5.87680322 6.29064137 4.73881015 2.59064177]
 [17.51496583 4.73881015 17.85949633 3.71912921]
 [ 3.04503635 2.59064177 3.71912921 3.03537789]]
-----------------------------------------------------------------------------------------------
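Overall accuracy: 91.66666666666666 %
Note that this figure was recorded with a test loop over range(30,49), which covers 19 samples per class (57 in total), while the divisor is 60. It therefore corresponds to 55 correct classifications; relative to the 57 samples actually tested, the accuracy is 55/57, about 96.5%.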
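The pairwise discriminant used in the script can also be wrapped in a small reusable function. The sketch below is not the script above: it runs on synthetic, well-separated clusters so that it needs no CSV file, it estimates the class means from the training rows only, and the names `fit_fisher`, `Xa`, and `Xb` are hypothetical.

```python
import numpy as np

def fit_fisher(Xa, Xb):
    """Return (w, T) such that g(x) = w @ x + T > 0 favours class a."""
    ma, mb = Xa.mean(axis=0), Xb.mean(axis=0)
    Ca, Cb = Xa - ma, Xb - mb
    Sw = Ca.T @ Ca + Cb.T @ Cb           # pooled within-class scatter
    w = np.linalg.solve(Sw, ma - mb)     # w = Sw^{-1} (ma - mb)
    T = -0.5 * w @ (ma + mb)             # g is zero at the midpoint of the means
    return w, T

rng = np.random.default_rng(42)
# Synthetic 4-D clusters, 50 samples each (illustrative data only)
Xa = rng.normal([0.0, 0.0, 0.0, 0.0], 0.5, size=(50, 4))
Xb = rng.normal([4.0, 4.0, 4.0, 4.0], 0.5, size=(50, 4))

train, test = slice(0, 30), slice(30, 50)    # 30 training, 20 test rows per class
w, T = fit_fisher(Xa[train], Xb[train])

correct_a = np.sum(Xa[test] @ w + T > 0)     # class-a test rows with g > 0
correct_b = np.sum(Xb[test] @ w + T < 0)     # class-b test rows with g < 0
accuracy = (correct_a + correct_b) / (Xa[test].shape[0] + Xb[test].shape[0])
print(accuracy)
```

Dividing by the number of rows actually tested, instead of a hard-coded constant, keeps the accuracy correct if the train/test split changes.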
References
Zhou Zhihua, Machine Learning (《机器学习》)
http://bob0118.club/?p=266