在这里就不说明LDA的原理了,不懂的同学可以百度找相关资料学。这里直接给出楼主的python实现,以及搜索到的其他实现。
楼主实现:
#实现线性判别分析算法
#二类n维降至一维
#传入的data,target都是array,target分别是0和1
def lda(data,target):
target=target.flatten() #将target变成意味数组
print(target.shape)
df1=data[target==0] #第0类
df2=data[target==1] #第1类
n=data.shape[1]
u1=df1.mean(0).reshape((1,n))
u2=df2.mean(0).reshape((1,n))
print(u2.shape)
data_mean_1=df1-u1 #应用numpy的广播机制
data_mean_2=df2-u2
print(data_mean_1.T.dot(data_mean_1))
Sw=data_mean_1.T.dot(data_mean_1)+data_mean_2.T.dot(data_mean_2)
w=np.mat((u1-u2))*np.mat(Sw).I
return w
#多类别降至K维的实现
def lda_muliti_class(data,target,K):
#within_class scatter matrix
clusters=unique(target)
if K>len(clusters)-1:
print("K is too much")
print("please input again")
exit(0)
Sw=np