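1 Introduction
This post applies the LDA (linear discriminant analysis) implementation in sklearn to the Iris dataset. For the theory, see the "watermelon book" (Zhou Zhihua's Machine Learning) or the sklearn documentation; only the code is explained here. When LDA is used for dimensionality reduction, the target dimension cannot exceed min(d, N-1), where d is the original number of features and N is the number of classes. For Iris (d = 4, N = 3) the projection can therefore have at most 2 dimensions.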
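The dimension limit can be checked directly. The sketch below is only an illustration: it loads the copy of Iris bundled with sklearn (`load_iris`) instead of the `iris.data` file, and the names `max_components`, `X_low` and `too_many_rejected` are made up for this example. It assumes a reasonably recent sklearn release, where asking for too many components raises a `ValueError`.

```python
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Bundled Iris data: 150 samples, 4 features, 3 classes
X, y = load_iris(return_X_y=True)

n_features = X.shape[1]
n_classes = len(set(y))
max_components = min(n_features, n_classes - 1)

# Projecting onto the largest allowed number of discriminant axes works
lda = LinearDiscriminantAnalysis(n_components=max_components)
X_low = lda.fit_transform(X, y)
print(max_components, X_low.shape)

# Asking for one more component than allowed is expected to be rejected
try:
    LinearDiscriminantAnalysis(n_components=max_components + 1).fit(X, y)
    too_many_rejected = False
except ValueError:
    too_many_rejected = True
print(too_many_rejected)
```

On Iris this means `n_components` can only be 1 or 2.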
2 Import modules
import pandas as pd  # pandas, used to read the data file
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
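3 Load the dataset
path = "iris.data"
# The file has no header row, so the columns are labelled 0..4:
# columns 0-3 hold the four features, column 4 holds the class label
df = pd.read_csv(path, header=None)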
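If `iris.data` is not available locally, a DataFrame with the same layout can be built from the copy of Iris that ships with sklearn. This is a stand-in under stated assumptions, not the loading step used in this post: the bundled data may differ from the file in a few values, and its class names (`setosa`, ...) are not necessarily spelled the same as the labels in the file.

```python
import pandas as pd
from sklearn.datasets import load_iris

iris = load_iris()

# Columns 0..3: the four measurements, same integer labels as header=None gives
df = pd.DataFrame(iris.data)

# Column 4: class name for each sample, looked up from the integer targets
df[4] = iris.target_names[iris.target]

print(df.shape)
print(df.head())
```

The rest of the code can then use `df` unchanged, because the column labels are still the integers 0 to 4.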
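4 Split into training and test sets
df = shuffle(df)  # shuffle the rows randomly
r = int(len(df) * 0.7)  # first 70% for training, remaining 30% for testing
train_data = df[:r]
test_data = df[r:]
# .loc slices by label and includes the end label, so :3 selects columns 0..3
train_x = train_data.loc[:, :3]
train_y = train_data.loc[:, 4]
test_x = test_data.loc[:, :3]
test_y = test_data.loc[:, 4]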
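The manual shuffle-and-slice above can also be written with sklearn's `train_test_split`. The version below is an alternative sketch rather than the method used in this post: it fixes `random_state=0` so the split is reproducible and passes `stratify` so each class keeps the same proportion in both parts. Both choices are assumptions added here, and the data again comes from `load_iris` to keep the example self-contained.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Stand-in for the DataFrame read from iris.data (columns 0..3 features, 4 label)
iris = load_iris()
df = pd.DataFrame(iris.data)
df[4] = iris.target_names[iris.target]

features = df.loc[:, :3]
labels = df.loc[:, 4]

# 70/30 split, shuffled, with class proportions preserved in both parts
train_x, test_x, train_y, test_y = train_test_split(
    features, labels, test_size=0.3, stratify=labels, random_state=0
)

print(len(train_x), len(test_x))
```

Dropping `random_state` gives a different split on every run, which is closer to the behaviour of the plain `shuffle(df)` call.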
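5 Train and predict
n_components = 1  # number of dimensions to keep; for Iris this can be 1 or 2
a_list = []  # stores the accuracy of each run
clf = LinearDiscriminantAnalysis(solver='svd', n_components=n_components)
clf.fit(train_x, train_y)  # fit on the training set
test_pred = clf.predict(test_x)  # predict labels for the test set
a = accuracy_score(test_y, test_pred)  # fraction of correct predictions
a_list.append(a)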
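A single run puts only one value into `a_list`. To get a curve of accuracies, the shuffle, split, fit and predict steps can be repeated. The loop below is one possible way to do that, not necessarily the author's full code: the run count `num = 20` and the use of the loop index as `random_state` are assumptions made for this sketch, and the data comes from `load_iris` so the block runs on its own.

```python
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle

X, y = load_iris(return_X_y=True)

num = 20          # hypothetical number of repeated runs
n_components = 1  # same setting as above
a_list = []       # one accuracy per run

for seed in range(num):
    # New random order on each run; the seed only makes the runs reproducible
    X_s, y_s = shuffle(X, y, random_state=seed)
    r = int(len(X_s) * 0.7)

    clf = LinearDiscriminantAnalysis(solver='svd', n_components=n_components)
    clf.fit(X_s[:r], y_s[:r])
    test_pred = clf.predict(X_s[r:])
    a_list.append(accuracy_score(y_s[r:], test_pred))

print(len(a_list), min(a_list), max(a_list))
```

In sklearn, `n_components` controls the output of `transform`, while `predict` does not depend on it, so changing `n_components` is not expected to move these accuracies; the variation between runs comes from the random split.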
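6 Plot
figure = plt.figure()
plt.rcParams['font.sans-serif'] = ['SimHei']  # only needed when the labels contain Chinese text
# One point per recorded accuracy; the x axis is the run index
plt.plot(list(range(len(a_list))), a_list, marker='o')
plt.title("Iris data - Fisher discriminant analysis - randomly shuffled - {} component(s) - accuracy".format(n_components))
plt.show()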