几分钟写了个 KNN 的 Python 代码，安装 numpy 和 mlxtend 后可以直接运行：
"""
program: KNN (k-nearest neighbours) classification
description:
1. calculate the distance between the test sample and every training sample
2. sort the distances
3. select the k training samples with the smallest distances
4. count the label frequencies among those k samples
5. return the label with the highest frequency
"""
from mlxtend.data import iris_data
import numpy as np
class knn_csy(object):
    """Brute-force k-nearest-neighbour classifier.

    Every query computes the Euclidean distance to all stored training
    samples and takes a majority vote over the labels of the k closest ones.
    """

    def __init__(self, dataset, label):
        """Store the training data.

        :param dataset: training samples, indexable by row (e.g. a 2-D numpy
            array, or a list of 1-D arrays/lists).
        :param label: label of each row of ``dataset``. Votes are counted
            with ``np.bincount``, so labels must be non-negative integers.
        """
        self.dataset = dataset
        self.label = label

    def distance(self, dataset_i, testdata):
        """Return the Euclidean distance between two equal-length vectors."""
        # asarray lets plain Python lists be subtracted element-wise as well.
        diff = np.asarray(dataset_i) - np.asarray(testdata)
        return np.sqrt(np.sum(diff ** 2))

    def calculate_dis(self, testdata, k=10, updateflage=0):
        """Classify ``testdata`` by majority vote among its k nearest samples.

        :param testdata: feature vector with the same length as a training row.
        :param k: number of neighbours that vote (default 10). A value larger
            than the training-set size simply uses every sample.
        :param updateflage: if truthy, append ``testdata`` and its predicted
            label to the stored training data after classification.
        :return: the predicted label; on a tied vote the smallest label wins.
        :raises ValueError: if ``testdata`` has the wrong length or ``k < 1``.
        """
        if len(testdata) != len(self.dataset[0]):
            raise ValueError("wrong input array of testdata")
        if k < 1:
            raise ValueError("k must be a positive integer")
        dis = [self.distance(row, testdata) for row in self.dataset]
        # Sorting (distance, label) pairs breaks distance ties by label;
        # slicing (unlike indexing) tolerates k > number of samples.
        nearest = sorted(zip(dis, self.label))[:k]
        # argmax returns the first (i.e. smallest) label with the top count.
        count = np.bincount([lab for _, lab in nearest])
        label = np.argmax(count)
        if updateflage:
            self._append_sample(testdata, label)
        return label

    def _append_sample(self, testdata, label):
        """Add one sample/label pair to the stored training data."""
        # numpy arrays have no .append(), so they are rebuilt; lists are
        # extended in place (which also mutates the list the caller passed).
        if isinstance(self.dataset, np.ndarray):
            self.dataset = np.vstack([self.dataset, testdata])
        else:
            self.dataset.append(testdata)
        if isinstance(self.label, np.ndarray):
            self.label = np.append(self.label, label)
        else:
            self.label.append(label)
if __name__ == '__main__':
    # Demo: fit on the Iris data and classify one hand-made sample with k=3.
    dataset, labels = iris_data()
    myknn = knn_csy(dataset, labels)
    testdata = [2, 1, 1, 2]
    predicted = myknn.calculate_dis(testdata, 3)
    # print(...) with one argument behaves the same on Python 2 and 3.
    print(predicted)