算法主要思想
- 已知多个样本的分类情况(假设如下图所示)。其中有两个分类,可以简记为红类/蓝类。而黑点则是带判断的目标点。
- K代表与目标点(黑点)距离最近的K个点。
- 计算黑点与K个点之间的距离(这里可以采用欧式距离)
- 按照少数服从多数的规则,确定目标点的类别。例如,距离黑点最近的5个点当中,红点占3个,蓝点占2个,则该黑点属于红类。
该算法的思路较为简单粗暴。缺点如下所示:
python实现
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
# 图形展示
x1 = np.array([1,2,3])
y1 = np.array([88,99,100])
x2 = np.array([10,11,12])
y2 = np.array([3.3,4,5.5])
x_test = np.array([6.])
y_test = np.array([4.])
scatter1 = plt.scatter(x1, y1)
scatter2 = plt.scatter(x2, y2)
scatter3 = plt.scatter(x_test, y_test)
plt.legend(handles = [scatter1, scatter2, scatter3], labels=["label1", 'label0', 'X'], loc='best')
plt.show()
x_data1 = np.array([[x,y] for x, y in zip(x1, y1)])
x_data2 = np.array([[x,y] for x, y in zip(x2, y2)])
x_data = np.concatenate((x_data1, x_data2), axis=0)
y_data = np.array([1,1,1,0,0,0])
x_test = np.array([6, 4])
k=3
result =np.sqrt(np.sum((x_test - x_data)**2., axis=1))
k_idx = list(result.argsort())[0:k] # 获得最小的几个数字的索引
print("clsss_"+str(Counter(y_data[tuple([k_idx])]).most_common(1)[0][0]))