什么是KNN算法
(41条消息) K-近邻算法(KNN)_<Running Snail>的博客-CSDN博客_k近邻算法
K近邻(K-Nearest Neighbor, KNN)是一种最经典和最简单的有监督学习方法之一。K-近邻算法是最简单的分类器,没有显式的学习过程或训练过程,是懒惰学习(Lazy Learning)。当对数据的分布只有很少或者没有任何先验知识时,K 近邻算法是一个不错的选择。
K近邻算法既能够用来解决分类问题,也能够用来解决回归问题。该方法有着非常简单的原理:当对测试样本进行分类时,首先通过扫描训练样本集,找到与该测试样本最相似的个训练样本,根据这个样本的类别进行投票确定测试样本的类别。也可以通过个样本与测试样本的相似程度进行加权投票。如果需要以测试样本对应每类的概率的形式输出,可以通过个样本中不同类别的样本数量分布来进行估计。
举个例子:
图中绿色的点就是我们要预测的那个点,假设K=3。那么KNN算法就会找到与它距离 最近的三个点(这里用圆圈把它圈起来了),看看哪种类别多一些,比如这个例子中 是蓝色三角形多一些,新来的绿色点就归类到蓝三角了
但是,当K=5的时候,判定就变成不一样了。这次变成红圆多一些,所以新来的绿 点被归类成红圆。从这个例子中,我们就能看得出K的取值是很重要的。
KNN实现步骤
1.计算距离
通常使用欧几里得距离或者马氏距离
欧几里得距离:p =
马氏距离: d = (马氏距离(Mahalanobis Distance) - 知乎 (zhihu.com))
2.升序排列
3.去前K个
K太大:导致分类模糊
K太小:受个例影响,波动较大
- 一般k值较小。
- k通常取奇数,避免产生相等占比的情况。
- 往往需要通过**交叉验证(Cross Validation)**等方法评估模型在不同取值下的性能,进而确定具体问题的K值。
4.加权平均
KNN实例
利用KNN算法求病人癌症检测的正确率
使用数据集:
代码部分:
import csv
import random
# 读取数据
with open("E:\Prostate_Cancer.csv","r") as f:
render = csv.DictReader(f)
datas = [row for row in render]
# 分组,打乱数据
random.shuffle(datas)
n = len(datas)//3
test_data = datas[0:n]
train_data = datas[n:]
# print (train_data[0])
# print (train_data[0]["id"])
# 计算对应的距离
def distance(x, y):
res = 0
for k in ("radius","texture","perimeter","area","smoothness","compactness","symmetry","fractal_dimension"):
res += (float(x[k]) - float(y[k]))**2
return res ** 0.5
# K=6
def knn(data,K):
# 1. 计算距离
res = [
{"result":train["diagnosis_result"],"distance":distance(data,train)}
for train in train_data
]
# 2. 排序
sorted(res,key=lambda x:x["distance"])
# print(res)
# 3. 取前K个
res2 = res[0:K]
# 4. 加权平均
result = {"B":0,"M":0}
# 4.1 总距离
sum = 0
for r in res2:
sum += r["distance"]
# 4.2 计算权重
for r in res2 :
result[r['result']] += 1-r["distance"]/sum
# 4.3 得出结果
if result['B'] > result['M']:
return "B"
else:
return "M"
# print(distance(train_data[0],train_data[1]))
# 预测结果和真实结果对比,计算准确率
for k in range(1,11):
correct = 0
for test in test_data:
result = test["diagnosis_result"]
result2 = knn(test,k)
if result == result2:
correct += 1
print("k="+str(k)+"时,准确率{:.2f}%".format(100*correct/len(test_data)))
运行结果:
当k=1时,正确率在54%,当k>5时,正确率回归为54%,印证了我们之前的结论:k的取值要适度,过大过小都不行。