1. Introduction
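Linear discriminant analysis (LDA) projects the data onto a straight line so that points of the same class land as close together as possible and points of different classes land as far apart as possible.
As shown in the figure below:
Intuitively, the projection in the right panel separates the classes better than the one in the left panel.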
2. Rayleigh quotient and generalized Rayleigh quotient:
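For a Hermitian matrix $A$ (real symmetric in the real case) and a non-zero vector $x$, the Rayleigh quotient is defined as:
$$R(A, x) = \frac{x^H A x}{x^H x}$$
One of its properties is that its maximum equals the largest eigenvalue of $A$ and its minimum equals the smallest eigenvalue of $A$, that is:
$$\lambda_{\min} \le \frac{x^H A x}{x^H x} \le \lambda_{\max}$$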
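The bound on the Rayleigh quotient can be checked numerically. The following is a minimal sketch, assuming a small random real symmetric matrix; the helper name `rayleigh_quotient` and the matrix size are illustrative choices, not part of the original post.

```python
import numpy as np


def rayleigh_quotient(A, x):
    """R(A, x) = x^T A x / x^T x for a real symmetric A and a non-zero x."""
    return float(x @ A @ x) / float(x @ x)


rng = np.random.default_rng(0)
B = rng.standard_normal((4, 4))
A = (B + B.T) / 2.0  # symmetrize so that all eigenvalues are real

# eigh returns eigenvalues in ascending order, eigenvectors as columns
eigvals, eigvecs = np.linalg.eigh(A)
lam_min, lam_max = eigvals[0], eigvals[-1]

# Rayleigh quotients of random vectors stay inside [lam_min, lam_max]
samples = [rayleigh_quotient(A, rng.standard_normal(4)) for _ in range(1000)]

# The extremes are attained at the corresponding eigenvectors
r_at_min = rayleigh_quotient(A, eigvecs[:, 0])
r_at_max = rayleigh_quotient(A, eigvecs[:, -1])

print("eigenvalue range :", lam_min, lam_max)
print("sampled range    :", min(samples), max(samples))
print("at eigenvectors  :", r_at_min, r_at_max)
```

Random vectors only sample the interior of the interval, which is why the eigenvectors are evaluated separately to show where the bounds are reached.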
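In LDA, we assume a data set $D = \{(x_1, y_1), \dots, (x_m, y_m)\}$ with labels $y_i \in \{0, 1\}$. For class $j \in \{0, 1\}$, let $X_j$ be the set of its samples and $N_j$ the number of samples in it.
The mean vector of class $j$ is:
$$\mu_j = \frac{1}{N_j} \sum_{x \in X_j} x$$
The covariance matrix of class $j$ (written without the $1/N_j$ factor, which matches the Python code later in this post) is:
$$\Sigma_j = \sum_{x \in X_j} (x - \mu_j)(x - \mu_j)^T$$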
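Let $\mu_j$ and $\Sigma_j$ be the mean and covariance of class $j$, and let $w$ be the projection direction. To keep points of the same class as close as possible, the projected covariance $w^T \Sigma_j w$ of each class must be as small as possible; to push different classes as far apart as possible, the distance between the projected class centers $w^T \mu_0$ and $w^T \mu_1$ must be as large as possible. Combining the two gives the objective to maximize:
$$J(w) = \frac{\| w^T \mu_0 - w^T \mu_1 \|_2^2}{w^T \Sigma_0 w + w^T \Sigma_1 w} = \frac{w^T (\mu_0 - \mu_1)(\mu_0 - \mu_1)^T w}{w^T (\Sigma_0 + \Sigma_1) w}$$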
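From matrix properties, a matrix multiplied by its own transpose is always positive semidefinite; it is positive definite only when it also has full rank. We can define the between-class scatter matrix $S_b$ and the within-class scatter matrix $S_w$:
$$S_b = (\mu_0 - \mu_1)(\mu_0 - \mu_1)^T, \qquad S_w = \Sigma_0 + \Sigma_1$$
Here $S_b$ has rank one and is positive semidefinite, while $S_w$ is positive definite (and therefore invertible) as long as the centered samples span the feature space.
The objective can then be rewritten as:
$$J(w) = \frac{w^T S_b w}{w^T S_w w}$$
which is a generalized Rayleigh quotient in $w$.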
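To make the criterion concrete, here is a minimal sketch that evaluates $J(w)$ for a few candidate directions on the watermelon samples listed in the data section of this post. The helper names `scatter_matrices` and `fisher_criterion` are hypothetical, and the covariance is left unnormalized as an assumption that mirrors the post's code.

```python
import numpy as np

# Samples copied from the data table in this post: (density, ratio_sugar)
X1 = np.array([[0.697, 0.460], [0.774, 0.376], [0.634, 0.264], [0.608, 0.318],
               [0.556, 0.215], [0.403, 0.237], [0.481, 0.149], [0.437, 0.211],
               [0.243, 0.267], [0.360, 0.370]])  # label == 1
X0 = np.array([[0.666, 0.091], [0.245, 0.057], [0.343, 0.099], [0.639, 0.161],
               [0.657, 0.198], [0.593, 0.042], [0.719, 0.103]])  # label == 0


def scatter_matrices(X0, X1):
    """Return (S_b, S_w) for two classes, using unnormalized covariances."""
    mu0, mu1 = X0.mean(axis=0), X1.mean(axis=0)
    d0, d1 = X0 - mu0, X1 - mu1
    S_w = d0.T @ d0 + d1.T @ d1    # Sigma_0 + Sigma_1
    diff = mu0 - mu1
    S_b = np.outer(diff, diff)     # (mu_0 - mu_1)(mu_0 - mu_1)^T
    return S_b, S_w


def fisher_criterion(w, S_b, S_w):
    """J(w) = (w^T S_b w) / (w^T S_w w)."""
    return float(w @ S_b @ w) / float(w @ S_w @ w)


S_b, S_w = scatter_matrices(X0, X1)

J_density = fisher_criterion(np.array([1.0, 0.0]), S_b, S_w)  # project on density
J_sugar = fisher_criterion(np.array([0.0, 1.0]), S_b, S_w)    # project on sugar

# J depends only on the direction of w, not on its length
w = np.array([0.3, -1.0])
J_w = fisher_criterion(w, S_b, S_w)
J_scaled = fisher_criterion(5.0 * w, S_b, S_w)

print("J(density axis) =", J_density)
print("J(sugar axis)   =", J_sugar)
print("J(w), J(5w)     =", J_w, J_scaled)
```

Because $J$ is invariant to rescaling $w$, only the direction matters, which is what allows the normalization used in the next step of the derivation.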
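With $S_b$ the between-class and $S_w$ the within-class scatter matrix, let $w = S_w^{-1/2} w'$, where $S_w^{-1/2}$ is the symmetric inverse square root of $S_w$. Substituting gives:
$$J = \frac{w'^T S_w^{-1/2} S_b S_w^{-1/2} w'}{w'^T w'}$$
This is an ordinary Rayleigh quotient, so maximizing $J$ amounts to finding the largest eigenvalue of $S_w^{-1/2} S_b S_w^{-1/2}$, which is also the largest eigenvalue of $S_w^{-1} S_b$. The optimal $w$ is the corresponding eigenvector of $S_w^{-1} S_b$, that is, $S_w^{-1} S_b w = \lambda w$.
In the two-class problem, the direction of $S_b w$ is always $\mu_0 - \mu_1$, because $S_b w = (\mu_0 - \mu_1)(\mu_0 - \mu_1)^T w$ and $(\mu_0 - \mu_1)^T w$ is a scalar. Since the scale of $w$ does not affect $J$, we can assume $S_b w = \lambda (\mu_0 - \mu_1)$; substituting this into $S_w^{-1} S_b w = \lambda w$ gives $w = S_w^{-1} (\mu_0 - \mu_1)$. Computing the class means and the within-class scatter matrix is therefore enough to obtain $w$.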
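The two-class shortcut can be checked against a direct eigen-decomposition. The sketch below is illustrative only: it assumes the watermelon samples from the data section of this post and the unnormalized covariance, and compares $S_w^{-1}(\mu_0 - \mu_1)$ with the leading eigenvector of $S_w^{-1} S_b$.

```python
import numpy as np

# Samples copied from the data table in this post: (density, ratio_sugar)
X1 = np.array([[0.697, 0.460], [0.774, 0.376], [0.634, 0.264], [0.608, 0.318],
               [0.556, 0.215], [0.403, 0.237], [0.481, 0.149], [0.437, 0.211],
               [0.243, 0.267], [0.360, 0.370]])  # label == 1
X0 = np.array([[0.666, 0.091], [0.245, 0.057], [0.343, 0.099], [0.639, 0.161],
               [0.657, 0.198], [0.593, 0.042], [0.719, 0.103]])  # label == 0

mu0, mu1 = X0.mean(axis=0), X1.mean(axis=0)
d0, d1 = X0 - mu0, X1 - mu1
S_w = d0.T @ d0 + d1.T @ d1          # within-class scatter
diff = mu0 - mu1
S_b = np.outer(diff, diff)           # between-class scatter (rank one)

# Closed form for the two-class case: w = S_w^{-1} (mu_0 - mu_1)
w_closed = np.linalg.solve(S_w, diff)

# General route: leading eigenvector of S_w^{-1} S_b
M = np.linalg.solve(S_w, S_b)        # equals S_w^{-1} S_b without forming the inverse
eigvals, eigvecs = np.linalg.eig(M)  # M is not symmetric, so use eig rather than eigh
top = int(np.argmax(eigvals.real))
w_eig = eigvecs[:, top].real
top_eigval = float(eigvals.real[top])

# Both vectors should point along the same line (up to sign and scale)
cosine = abs(w_closed @ w_eig) / (np.linalg.norm(w_closed) * np.linalg.norm(w_eig))

# The largest eigenvalue equals the maximum of J
J_closed = float(w_closed @ S_b @ w_closed) / float(w_closed @ S_w @ w_closed)

print("closed-form w      :", w_closed)
print("leading eigenvector:", w_eig)
print("|cosine|           :", cosine)
print("top eigenvalue, J  :", top_eigval, J_closed)
```

Using `np.linalg.solve` instead of an explicit inverse is a design choice for numerical stability; for a 2x2 matrix either route gives the same answer to within rounding.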
Data (saved as watermelon_3a.csv):
Idx,density,ratio_sugar,label
1,0.697,0.46,1
2,0.774,0.376,1
3,0.634,0.264,1
4,0.608,0.318,1
5,0.556,0.215,1
6,0.403,0.237,1
7,0.481,0.149,1
8,0.437,0.211,1
9,0.666,0.091,0
10,0.243,0.267,1
11,0.245,0.057,0
12,0.343,0.099,0
13,0.639,0.161,0
14,0.657,0.198,0
15,0.36,0.37,1
16,0.593,0.042,0
17,0.719,0.103,0
Python code:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def LoadDataSet():
    # Read the two features and the class label from the CSV file
    watermelon = pd.read_csv('watermelon_3a.csv',
                             usecols=['density', 'ratio_sugar', 'label'])
    return watermelon


def calulateW(df):
    feats = ['density', 'ratio_sugar']
    X1 = df.loc[df.label == 1, feats].to_numpy(dtype=float)  # positive class (label=1)
    X0 = df.loc[df.label == 0, feats].to_numpy(dtype=float)  # negative class (label=0)
    mean1 = X1.mean(axis=0)  # mean of the positive class (u1)
    mean0 = X0.mean(axis=0)  # mean of the negative class (u0)
    d1 = X1 - mean1          # (x - u1) for every positive sample
    d0 = X0 - mean0          # (x - u0) for every negative sample
    # Within-class scatter: sum of (x - u)(x - u)^T over both classes
    sw = d1.T @ d1 + d0.T @ d0
    # w = Sw^{-1} (u0 - u1), solved without forming the inverse explicitly
    w = np.linalg.solve(sw, mean0 - mean1)
    return w


def plot(df, w):
    pos = df[df.label == 1]
    neg = df[df.label == 0]
    plt.figure(1)
    ax = plt.subplot(111)
    ax.scatter(pos['density'], pos['ratio_sugar'], s=30, c='red', marker='s')
    ax.scatter(neg['density'], neg['ratio_sugar'], s=30, c='green')
    x = np.arange(-0.2, 1.0, 0.1)
    # Line w[0]*x + w[1]*y = 0: passes through the origin, perpendicular to w
    y = -w[0] * x / w[1]
    plt.sca(ax)
    plt.plot(x, y)
    plt.xlabel('density')
    plt.ylabel('ratio_sugar')
    plt.title('LDA')
    plt.show()


if __name__ == '__main__':
    df = LoadDataSet()
    w = calulateW(df)
    plot(df, w)
The resulting classification plot is shown below:
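To turn the direction $w$ into an actual classifier, one common option is to project each sample onto $w$ and compare the result with the midpoint of the two projected class means. The sketch below assumes that midpoint rule, which is an illustrative choice and not something the original post specifies; the `predict` helper is hypothetical, and the samples are embedded inline so the block runs without the CSV file.

```python
import numpy as np

# Samples copied from the data table in this post: (density, ratio_sugar)
X1 = np.array([[0.697, 0.460], [0.774, 0.376], [0.634, 0.264], [0.608, 0.318],
               [0.556, 0.215], [0.403, 0.237], [0.481, 0.149], [0.437, 0.211],
               [0.243, 0.267], [0.360, 0.370]])  # label == 1
X0 = np.array([[0.666, 0.091], [0.245, 0.057], [0.343, 0.099], [0.639, 0.161],
               [0.657, 0.198], [0.593, 0.042], [0.719, 0.103]])  # label == 0

mu0, mu1 = X0.mean(axis=0), X1.mean(axis=0)
d0, d1 = X0 - mu0, X1 - mu1
S_w = d0.T @ d0 + d1.T @ d1

# Same direction as in the post's code: w = S_w^{-1} (mu_0 - mu_1)
w = np.linalg.solve(S_w, mu0 - mu1)

# Assumed decision rule: threshold at the midpoint of the projected class means.
# Since w^T (mu_0 - mu_1) > 0, class 0 projects to the larger side.
threshold = (w @ (mu0 + mu1)) / 2.0


def predict(X):
    """Return 0 where the projection falls on the class-0 side, otherwise 1."""
    return np.where(X @ w > threshold, 0, 1)


X = np.vstack([X1, X0])
y = np.concatenate([np.ones(len(X1), dtype=int), np.zeros(len(X0), dtype=int)])
accuracy = float(np.mean(predict(X) == y))

print("w                 =", w)
print("threshold         =", threshold)
print("training accuracy =", accuracy)
```

The corresponding decision boundary in the feature plane is the line $w^T x = $ `threshold`, which is parallel to the line $w^T x = 0$ drawn by the `plot` function but shifted so that it passes between the two class means.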