本文内容为《Python大战机器学习》参考书第一章线性模型的部分学习笔记
简单记忆:数据降维方式的一种,最常用的数据降维方式是PCA(主成分分析)
数据集使用的是鸢尾花数据
from sklearm.datasets import load_iris
df = load_iris()
print(df.DESCR)
将数据集拆分为训练集和测试集
from sklearn.model_selection import train_test_split
X_train, X_test,y_train,y_test = train_test_split(df.data,df.target,test_size=0.25,random_state=0,stratify=df.target)
stratify 分层 stratified fashion 分层方式 因为鸢尾花数据集不是随机分布,而是从上到下按同一个类别在一起排列的,所以拆分数据的时候需要指定stratify这个参数
使用help(train_test_split)查看帮助文档 运行帮助文档中的例子
import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(10).reshape((5, 2)), range(5)
X
array([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
list(y)
[0, 1, 2, 3, 4]
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42)
...
X_train
array([[4, 5],
[0, 1],
[6, 7]])
y_train
[2, 0, 3]
X_test
array([[2, 3],
[8, 9]])
y_test
[1, 4]
train_test_split(y, shuffle=False)
[[0, 1, 2], [3, 4]]
引入线性判别模型、拟合、预测
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
LDA = LinearDiscriminantAnalysis()
LDA.fit(X_train,y_train)
Out[18]:
LinearDiscriminantAnalysis(n_components=None, priors=None, shrinkage=None,
solver='svd', store_covariance=False, tol=0.0001)
LDA.score(X_train,y_train)
Out[19]: 0.9732142857142857
LDA.predict(X_test)
Out[20]:
array([0, 0, 0, 0, 1, 1, 1, 0, 1, 2, 2, 2, 1, 2, 1, 0, 0, 2, 0, 1, 2, 1, 1,
0, 2, 0, 0, 1, 2, 1, 0, 1, 2, 2, 0, 1, 2, 2])
LDA.score(X_test,y_test)
Out[21]: 1.0
dir(LDA)
画图对拟合后的结果进行展示
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
X = np.vstack((X_train,X_test))
Y = np.vstack((y_train.reshape(y_train.size,1),\
y_test.reshape(y_test.size,1)))
X.shape
Y.shape
converted_X = np.dot(X,np.transpose(LDA.coef_)) + \
LDA.intercept_
fig = plt.figure()
ax = Axes3D(fig)
colors = 'rgb'
markers = 'o*s'
for target,color,marker in zip([0,1,2],colors,markers):
pos = (Y == target).ravel()
X = converted_X[pos,:]
ax.scatter(X[:,0],X[:,1],X[:,2],\
color=color,marker=marker,\
label = "Label%d"%target)
ax.legend(loc="best")
fig.suptitle("Iris After LDA")
plt.show()
从上图可以看出经过判别分析后,三个品种的鸢尾花能够很好的区别开
numpy中的函数需要进一步掌握
np.vstack()
np.dot() zip()