KNN算法分析实战(鸢尾花数据集)

KNN算法分析实战(鸢尾花数据集)
目录

KNN算法分析实战(鸢尾花数据集)

代码效果图

一、导入需要的包

二、

1.导入数据

2.建立训练集和测试集

3.设置K值

4. 十重交叉验证K值

5.模型拟合 

6.数据可视化输出

提示:以下是本篇文章正文内容,下面案例可供参考

一、导入需要的包
要是报错的话可以在pycharm安装包,要是不行就在命令窗口输入pip install +包名

import matplotlib.pyplot as plt
from sklearn import neighbors
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import model_selection
from sklearn import metrics
二、
1.导入数据
导入数据并查看前5行代码

df1 = pd.read_csv(r'D:\python\iris.csv')
print(df1.head())#输出前五行
predictors = df1.columns[:-1]
 

2.建立训练集和测试集
代码如下:

x_train,x_test,y_train,y_test=model_selection.train_test_split(
    df1[predictors],df1.Species,
    test_size=0.5,
    random_state = 1234
)
print(np.ceil(np.log2(df1.shape[0])))
3.设置K值
#设置待测试的不同K值
K = np.arange(1,np.ceil(np.log2(df1.shape[0])))
print(np.arange(1,np.ceil(np.log2(df1.shape[0]))))
#设置空列表,用于储存平均准确率
accuracy = []
4. 十重交叉验证K值
使用十重交叉验证K值,并做出最适合K值的折线图

#使用十重交叉验证的方法
for k in K:
    cv_result = model_selection.cross_val_score\
        (neighbors.KNeighborsClassifier(n_neighbors=int(k),
                                        weights='distance'),
         x_train, y_train, cv=10, scoring='accuracy')
    accuracy.append(cv_result.mean())
 
#从K个平均准确率中挑选出最大值做对应的目标
arg_max = np.array(accuracy).argmax()
#中文负号正常显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
#绘制不同k值与准确率之间的折线图
plt.plot(K,accuracy)
plt.scatter(K,accuracy)
plt.text(K[arg_max],accuracy[arg_max],'最佳K值为%s'%int(K[arg_max]))
plt.show()
 

5.模型拟合 
代入K值,进行模型拟合

#重新构建模型,并将最佳邻近数个数设置为7
knn_class = neighbors.KNeighborsClassifier(n_neighbors=7,weights='distance')
#模型拟合
knn_class.fit(x_train,y_train)
#模型在测试集上的预测
predict = knn_class.predict(x_test)
6.数据可视化输出
#构建混淆矩阵
cm = pd.crosstab(predict,y_test)
print(f'鸢尾花种类混淆矩阵\n{cm}')
#热力图输出
cm = pd.DataFrame(cm,columns=['setosa','versicolor','virginica'],
                  index=['setosa','versicolor','virginica'])
sns.heatmap(cm,annot=True,cmap='GnBu')
plt.xlabel('Real Lable')
plt.ylabel('Predict Lable')
plt.title('鸢尾花种类热力图')
plt.show()
#显示各类预测准确率
b = metrics.classification_report(y_test,predict)
print(f'显示各类预测准确率\n{b}')

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值