目录
1、KNN算法概述
KNN可以说是最简单的分类算法之一,同时,它也是最常用的分类算法之一,注意KNN算法是有监督学习中的分类算法,可以参考一下本篇博客。还有一种是机器学习算法Kmeans,它是无监督学习算法(这里不做过多介绍,可以参考一下本篇博客。
2、什么是knn算法
通俗的说就是根据与要调查的数据相似的数据,来判断要调查的数据的属性和分类。这就好比老”物以类聚,人以群分“,要想了解一个人,可以从他的朋友了解。这就像数学中的找规律一
样,给你n组相似的数据,通过算法来判断某n个或n组数据是什么。
其实啊,KNN的原理就是当预测一个新的值x的时候,根据它距离最近的K个点是什么类别来判断x属于哪个类别。听起来有点绕,还是看看图吧。
图中绿色的点就是我们要预测的那个点,假设K=3。那么KNN算法就会找到与它距离最近的三个点(这里用圆圈把它圈起来了),看看哪种类别多一些,比如这个例子中是蓝色三角形多一些,新来的绿色点就归类到蓝三角了。
但是,当K=5的时候,判定就变成不一样了。这次变成红圆多一些,所以新来的绿点被归类成红圆。
3、算法原理解析
3.1、通用步骤
- 计算距离(常用欧几里得距离(推荐使用较为简单)或者马氏距离,不理解的可以看一看俩者的比较)
- 升序排列:按照距离的远近进行排列(用python内置函数进行排序)
- 取前k个:在所得的距离中取前k个距离(k任意)数据
- 加权平均值:对取得的前k个距离数据进行加权平均(原因:前k个距离可能与目标数据的差距都不一样,所以要使用加权平均,而不使用算术平均,这也减小了误差)
3.2、k的取值对预测结果的影响
- k太大:导致分类模糊
- k太小:受个别样例影响,波动较大,误差较大
k值和误差的关系
3.3、k的选取
- 根据经验(不建议初学者)
- 多尝试几次,找到比较靠谱的k值
- 均方差误差
均方差示例图
4、实战应用 (癌症检测数据)
完整代码:
import random
import csv
#打开文件
with open('E:\maker\培训内容\鸿蒙\相关\阶段五\相关数据集\prostate-cancer\Prostate_Cancer.csv','r') as file:
#读取文件
reader=csv.DictReader(file)
datas=[row for row in reader]
#分组
random.shuffle(datas)#打乱数据的顺序
n=len(datas)//3#得到数据的长度整除3,分成三份
test_set=datas[0:n]#1/3为测试
train_set=datas[n:]#2/3为训练
#KNN
#距离
def distance(d1,d2):
#求和
sum=0
#将计算用的数据放在一个元组中
for key in("radius","texture","perimeter","area","smoothness","compactness","symmetry","fractal_dimension"):
#转化为浮点型数据
sum+=(float(d1[key])-float(d2[key]))**2
#平方后在开方
return sum**0.5
K=9
def knncount(data):
#1、距离
sum=[
{"result":train['diagnosis_result'],"distance":distance(data,train)}
for train in train_set #解释一下,这里为什么要用train_set,因为在这个示例中train_set是已知的,我们要根据已知来判断test_set中的数据的结果,寻找data和train_set的关系
]
#2、排序---升序
sum=sorted(sum,key=lambda item:item['distance'])
#3、取前k个
sum1=sum[0:K]
#4、加权平均
result={'B':0,'M':0}#0代表权重
#总距离求和
s=0
for r in sum1:
s+=r['distance']
#判断是给B加权还是给M加权
for r in sum1:
result[r['result']]+=1-r['distance']/s #result[r['result']]表示,result[B]还是result[M]给谁加权
#输出预测结果
if result['B']>result['M']:
return 'B'
else:
return 'M'
#测试
correct=0
for test in test_set:
result=test['diagnosis_result']#真实结果
result1=knncount(test)#测试结果
#如果真实结果和预测结果相同那么正确个数加1
if result==result1:
correct+=1
print("准确率:{:.2f}%".format(100*correct/len(test_set)))
运行结果(部分截图):
本例中的数据的网盘地址为:百度网盘 请输入提取码
提取码:zxmt
数据集部分截图:
学习视频网址:【智源学院】30分钟KNN算法-有意思专题系列(K-Nearest Neighbor, KNN)_哔哩哔哩_bilibili