Python实现线性判别,参考周志华机器学习第三章
import numpy as np
import matplotlib.pyplot as plt
mat=np.loadtxt('gua.txt')
X1=mat[0:8,1:3] #好瓜的数据
X0=mat[8:,1:3] #坏瓜的数据
mean1=np.mean(X1, axis=0) #计算每一列的平均值
mean0=np.mean(X0,axis=0)
mean1=mean1.reshape(2,1)
mean0=mean0.reshape(2,1)
cov0 = np.cov(X0,rowvar=False) #坏瓜协方差矩阵
cov1 = np.cov(X1,rowvar=False) #好瓜协方差矩阵
Sw=cov0+cov1
w=np.dot(Sw**(-1),mean1-mean0)
print(w)
plt.scatter(X0[:,0],X0[:,1],label='0')
plt.scatter(X1[:,0],X1[:,1],label='1')
plt.plot([0,1],[0,-w[0]/w[1]],label='y')
plt.show()