KNN算法

KNN算法:近朱者赤,近墨者黑

最近邻算法(K-NearestNeighbor),简称KNN算法
算法原理:找到K个与新数据最近的样本,取样本中最多的一个类别作为新数据的类别
算法优点:简单易实现;对边界不规则的数据效果较好
算法缺点:只适合小数据集;数据不平衡效果不好;必须要做数据标准化;不适合特征维度太多的数据
关于K的选取
会影响到模型的效果
K越小,易过拟合
K越大,易欠拟合
实例体会

from sklearn import datasets  #sklearn数据集
from sklearn.neighbors import KNeighborsClassifier #sklearn模块的KNN类
import numpy as np #矩阵运算库numpy

np.random.seed(0)
#设置随机种子,不设置的话默认是按系统时间作为参数,设置后可以保证我们每次产生的随机数是一样的
iris = datasets.load_iris()#获取鸢尾花数据集
iris_x = iris.data #数据部分
iris_y = iris.target #类别部分
print(iris_x)
#鸢尾花数据集主要包含了鸢尾花的花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性(特征)
print(iris_y)
#鸢尾花卉属于『Setosa,Versicolour,Virginica』三个种类中的哪一类
#每个类别有50条数据

#从150条数据中选140条作为训练集,10条作为测试集。permutation接收一个数作为参数(这里为数据集长度150),产生一个0-149乱序一维数组
randommarr = np.random.permutation(len(iris_x))
iris_x_train = iris_x[randommarr[:-10]] #训练集数据
print(iris_x_train)
'''
 [6.  2.2 4.  1. ]
 [5.5 4.2 1.4 0.2]
 [7.3 2.9 6.3 1.8]
 [5.  3.4 1.5 0.2]
 [6.3 3.3 6.  2.5]
 [5.  3.5 1.3 0.3]
 [6.7 3.1 4.7 1.5]
 [6.8 2.8 4.8 1.4]
 [6.1 2.8 4.  1.3]
 [6.1 2.6 5.6 1.4]
 [6.4 3.2 4.5 1.5]
 [6.1 2.8 4.7 1.2]
 ...
 '''
 
iris_y_train = iris_y[randommarr[:-10]] #训练集标签
print(iris_y_train)
'''2 1 0 2 0 2 0 1 1 1 2...'''
iris_x_test = iris_x[randommarr[-10:]] #测试集数据
print(iris_x_test)
'''
 [5.6 3.  4.1 1.3]
 [5.9 3.2 4.8 1.8]
 [6.3 2.3 4.4 1.3]
 [5.5 3.5 1.3 0.2]
 [5.1 3.7 1.5 0.4]
 [4.9 3.1 1.5 0.1]
 [6.3 2.9 5.6 1.8]
 [5.8 2.7 4.1 1. ]
 [7.7 3.8 6.7 2.2]
 [4.6 3.2 1.4 0.2]
 '''
iris_y_test = iris_y[randommarr[-10:]] #测试集标签
print(iris_y_test)
'''1 1 1 0 0 0 2 1 2 0'''

#定义一个KNN分类器对象
knn = KNeighborsClassifier()
#调用该对象的训练方法,主要接收两个参数:训练数据集及其类别标签
knn.fit(iris_x_train, iris_y_train)

#调用预测方法,主要接受一个参数:测试数据集
iris_y_predict = knn.predict(iris_x_test)

#计算各测试样本预测的概率值,这里我们没有用概率值,但是在实际工作中可能会参考概率值来进行最后结果的筛选,而不是直接使用给出的预测标签
probility = knn.predict_proba(iris_x_test)

#选出最优的K值
#计算与最后一个测试样本距离最近的5个点,返回的是这些样本的序号组成的数组
neighborpoint = knn.kneighbors([iris_x_test[-1]],5)

#调用该对象的打分方法,计算出准确率
score = knn.score(iris_x_test,iris_y_test,sample_weight=None)

#输出测试结果
print('iris_y_predict=')
print(iris_y_predict)
iris_y_predict=
#[1 2 1 0 0 0 2 1 2 0]

#输出原始测试数据集的正确标签,以方便对比
print('iris_y_test=')
print(iris_y_test)
iris_y_test=
#[1 1 1 0 0 0 2 1 2 0]

#输出准确率计算结果
print('Accuracy:',score)
#Accuracy: 0.9

理解程度:50%
模范代表

  • 28
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值