1.算法原理
(1)计算距离(欧氏距离)
(2)升序排列(从小到大)
(3)取前K个
(4)加权平均
注意:K过大,分类结果模糊;K过小,受个别案例影响,波动大。
2.代码实现demo(python)
(1)以下代码实现找出距离最近的前五个训练样本,并按距离升序排列
import random
import csv
# Load the data: every CSV row becomes a dict keyed by the header names.
# newline='' follows the csv module's guidance so that line endings inside
# quoted fields are interpreted by the reader, not by the file object.
with open('Prostate_Cancer.csv', 'r', newline='') as file:
    reader = csv.DictReader(file)
    dataset = list(reader)
print(dataset)
# Split the data: shuffle in place, then hold out the first third for testing.
random.shuffle(dataset)
n = len(dataset) // 3  # "//" is floor division
# Rows [0, n) form the test set; the rest (from the 1/3 mark on) form the training set.
test_set, train_set = dataset[:n], dataset[n:]
# KNN step 1: compute the distance between two samples.
def distance(d1, d2):
    """Return the Euclidean distance between two rows over eight feature columns.

    Both arguments are mappings; the value stored under each feature key
    below is converted with float() before the difference is taken. Keys
    other than the listed features are ignored.
    """
    features = (
        "radius", "texture", "perimeter", "area",
        "smoothness", "compactness", "symmetry", "fractal_dimension",
    )
    squared_total = 0
    for name in features:
        gap = float(d1[name]) - float(d2[name])
        squared_total += gap ** 2
    return squared_total ** 0.5
K = 5  # number of nearest training samples that take part in the vote


def knn(data):
    """Print the distance-weighted class votes for ``data`` and its recorded label.

    ``data`` is a row mapping that carries the feature columns read by
    ``distance`` and a ``'diagnosis_result'`` key. Training labels must be
    ``'B'`` or ``'M'``; any other label raises ``KeyError``.

    Each of the K nearest training samples votes for its own label with
    weight ``1 - d / total``, where ``d`` is its distance and ``total`` is
    the sum of the K distances, so closer samples count more. When that
    formula is degenerate (total distance is 0, or there is only one
    neighbour) every neighbour votes with weight 1 instead, which avoids a
    division by zero and an all-zero vote.
    """
    # 1. Distance from `data` to every training sample.
    candidates = [
        {"result": train['diagnosis_result'], "distance": distance(data, train)}
        for train in train_set
    ]
    # 2-3. Sort ascending by distance and keep the K closest.
    nearest = sorted(candidates, key=lambda item: item['distance'])[:K]
    # 4. Weighted vote: near neighbours weigh more, far neighbours weigh less.
    result = {'B': 0, 'M': 0}
    total_distance = sum(r['distance'] for r in nearest)
    use_distance_weights = total_distance > 0 and len(nearest) > 1
    for r in nearest:
        if use_distance_weights:
            weight = 1 - r['distance'] / total_distance
        else:
            weight = 1.0
        result[r['result']] += weight
    print(result)
    print(data['diagnosis_result'])
# Demo: classify the first test sample; knn prints the weighted votes
# followed by the sample's recorded diagnosis_result.
knn(test_set[0])
(2)测准确率
import random
import csv
# Load the data: every CSV row becomes a dict keyed by the header names.
# newline='' follows the csv module's guidance so that line endings inside
# quoted fields are interpreted by the reader, not by the file object.
with open('Prostate_Cancer.csv', 'r', newline='') as file:
    reader = csv.DictReader(file)
    dataset = list(reader)
# print(dataset)  # debug: uncomment to inspect the loaded rows
# Split the data: shuffle in place, then hold out the first third for testing.
random.shuffle(dataset)
n = len(dataset) // 3  # "//" is floor division
# Rows [0, n) form the test set; the rest (from the 1/3 mark on) form the training set.
test_set, train_set = dataset[:n], dataset[n:]
# KNN step 1: compute the distance between two samples.
def distance(d1, d2):
    """Return the Euclidean distance between two rows over eight feature columns.

    Both arguments are mappings; the value stored under each feature key
    below is converted with float() before the difference is taken. Keys
    other than the listed features are ignored.
    """
    features = (
        "radius", "texture", "perimeter", "area",
        "smoothness", "compactness", "symmetry", "fractal_dimension",
    )
    squared_total = 0
    for name in features:
        gap = float(d1[name]) - float(d2[name])
        squared_total += gap ** 2
    return squared_total ** 0.5
K = 5  # number of nearest training samples that take part in the vote


def knn(data):
    """Return the predicted label, ``'B'`` or ``'M'``, for ``data``.

    ``data`` is a row mapping that carries the feature columns read by
    ``distance``. Training labels must be ``'B'`` or ``'M'``; any other
    label raises ``KeyError``.

    Each of the K nearest training samples votes for its own label with
    weight ``1 - d / total``, where ``d`` is its distance and ``total`` is
    the sum of the K distances, so closer samples count more. When that
    formula is degenerate (total distance is 0, or there is only one
    neighbour) every neighbour votes with weight 1 instead, which avoids a
    division by zero and an all-zero vote. ``'B'`` is returned only when
    its vote is strictly larger; ties go to ``'M'``.
    """
    # 1. Distance from `data` to every training sample.
    candidates = [
        {"result": train['diagnosis_result'], "distance": distance(data, train)}
        for train in train_set
    ]
    # 2-3. Sort ascending by distance and keep the K closest.
    nearest = sorted(candidates, key=lambda item: item['distance'])[:K]
    # 4. Weighted vote: near neighbours weigh more, far neighbours weigh less.
    result = {'B': 0, 'M': 0}
    total_distance = sum(r['distance'] for r in nearest)
    use_distance_weights = total_distance > 0 and len(nearest) > 1
    for r in nearest:
        if use_distance_weights:
            weight = 1 - r['distance'] / total_distance
        else:
            weight = 1.0
        result[r['result']] += weight
    if result['B'] > result['M']:
        return 'B'
    return 'M'
# Score the classifier on the held-out set: count the test samples whose
# predicted label equals the recorded diagnosis_result.
correct = sum(1 for test in test_set if knn(test) == test['diagnosis_result'])
print(correct)
print(len(test_set))
# Report the accuracy as a fraction; skipped when there is nothing to score.
if test_set:
    print(correct / len(test_set))
3.CSV的数据集,有需要的私我!