一、knn的原理
knn是k-Nearest Neighbor的简称,是常用分类算法之一。当预测一个新的值x的时候,根据它距离最近的K个点是什么类别来判断x属于哪个类别。可以概括为近邻算法是采用测量不同特征值之间得距离得方法进行分类。
二、knn的步骤
1、计算测试样本与已知类别的训练样本的距离;
2.按照距离由近到远进行排序;
3.为测试样本选择k个与其距离最小的训练样本(一般k值取取奇数);
4.确定前k个样本所在类别的出现频率;
5.返回前k个样本出现频率最高的类别作为当前点的预测分类。
三、实践
我们通过python实现knn区分类别,主要包含创建训练样本集、knn算法实现、主函数三部分。
1、创建训练样本集和类别标签
def create_data_set():
#创建包含两个特征的训练样本-二维数组
training_data_set = np.array([[1.0,0.9],[1.1,1.0],[0.5,0.5],[0.2,0.5]])
#样本对应的类别
labels = ["优","优","良","差"]
return training_data_set,labels
2.knn算法实现
def knn_classification(target_data_set,training_data_set,labels,k):
"""
knn算法实现
:param target_data_set:待分类样本集
:param training_data_set:训练样本集
:param labels:与训练样本集对应的类别标签列表
:param k:
:return:
"""
#step1:计算待分类样本与训练样本集每个元素的欧式距离
#使用tile()将待分类样本进行复制训练数据的行数,并与训练样本集作差
diff = np.tile(target_data_set,(4,1))-training_data_set
#将差值样本在行方向(axis = 1,0是列方向)进行求平方和
squred_diff_sum = (diff**2).sum(axis=1)
distance = squred_diff_sum**0.5
#step2.按照距离由近到远进行排序
#argsort 用来将列表中的元素进行从小到大排列,返回的是一串索引
sorted_distance = distance.argsort()
#step3.根据排序结果选择k个近邻,并统计各个类别出现的次数
class_count_dict = {}
#按照排好顺序进行访问标签列表,并使用get()统计
for i in range(k):
label = labels[sorted_distance[i]]
#如存在标签,则直接相加,如不存在设置默认值为0
class_count_dict[label] =class_count_dict.get(label,0)+1
#step4.返回出现次数最多的类别标签
#按照字典的values(itemgetter(1))进行倒叙(reverse = True,默认是正序排列)排列
sorted_class_count = sorted(class_count_dict.items(),key = operator.itemgetter(1),reverse= True)
return sorted_class_count[0][0]
3.主函数
def main():
training_data_set,labels =create_data_set()
k = 3
target_data_set0 = np.array([1.0,1.0])
output_label0 = knn_classification(target_data_set0,training_data_set,labels,k)
print("待分类样本",target_data_set0,"\n类别为:",output_label0,"\n")
target_data_set1 = np.array([0.4, 0.5])
output_label1 = knn_classification(target_data_set1, training_data_set, labels, k)
print("待分类样本", target_data_set1, "\n类别为:", output_label1, "\n")
target_data_set2 = np.array([0.0, 0.0])
output_label2 = knn_classification(target_data_set2, training_data_set, labels, k)
print("待分类样本", target_data_set2, "\n类别为:", output_label2, "\n")
4.运行结果