KNN:K最近邻算法,K-Nearest Neighbor,是早期的基于统计的有监督分类方法。
举个例子:小明有10个朋友,这10个朋友的学习成绩有好有差(“好、中、差”),小明跟这10个朋友的关系也有亲疏;现在已知小明最好的5个朋友中有4个都是成绩好的,那么可以推断小明的成绩也是好的。
KNN算法的思想大致就是这么个意思,即“近朱者赤,近墨者黑”。
- Training Data:10个朋友的记录(可能有姓名、)
- K=5
一、KNN算法基本步骤
- 计算待分类特征数据到每个训练数据的距离(这里选择Euclidean Distance)
- 对(1)的距离进行升序排序
- 选取(2)排序后排名前K的距离(邻居)对应的训练数据及其类别标签
- K个邻居进行投票,票数最多的类别标签就是待分类特征数据的类别
二、iris数据集的KNN算法python实现
- 数据集:scikit-learn库预置的“toy”数据集iris。
#导入scikit-learn库预置的“toy”数据集iris
from sklearn import datasets
iris=datasets.load_iris()
iris_x = iris.data #iris特征数据
iris_y = iris.target #iris类别标签
#KNN算法函数
def knn(train_data, label, test_data, k):
##检验输入数据是否合规
'''if not train_data: #case1:无训练数据
raise ValueError("Invalid training data.")
elif not test_data: #case2:无效待分类数据
raise ValueError("Invalid test data.")
elif len(test_data)!=len(train_data[0]):
raise ValueError("Invalid test data.")'''
##计算所有距离(欧式)
#distances=distance(train_data,test_data)
distances=distance_iter(train_data,test_data) #[(0,d1),(1,d2),...]
##距离排序
sorted_distances=sorted(distances, key=lambda x:x[1])
##截取前K个邻居
k_neighbors=sorted_distances[:k]
##K个邻居投票
vote_result=vote(k_neighbors,label)
#返回分类结果
return vote_result
## for循环实现的距离计算
def distance(dataset, record):
#euclidean distance
#使用循环,数据量较大时存在执行效率低和溢出的问题
distances=[]
for data in dataset:
sqr_distance=0
for index in range(len(record)):
sqr_distance+=(data[index]-record[index])**2
distances.append(sqr_distance**0.5)
return list(enumerate(distances)) #生成带索引的列表
def vote(neighbors,label):
neighbors_labels=[label[index] for index,data in neighbors]
return max(neighbors_labels,key=neighbors_labels.count)
## iteration实现的距离计算
def distance_iter(dataset,record):
#euclidean distance
#使用生成器
return list(enumerate(distance_iter_generator(dataset,record)))
def distance_iter_euclidean(data1,data2):
sqr_sum=0
for i in range(len(data1)):
sqr_sum+=(data1[i]-data2[i])**2
return sqr_sum**0.5
def distance_iter_generator(dataset,record):
for data1 in dataset:
yield distance_iter_euclidean(data1,record)
#print(iris_x)
test1=[4.7,3.2,1.3,0.2] #label=0
test2=[5.9,3,5.1,1.8] #label=1
print(knn(iris_x,iris_y,test1,5)==0)
print(knn(iris_x,iris_y,test2,5)==1)
>>True
True