KNN算法概述
KNN(k-nearest neighbor)算法属于机器学习中的有监督分类算法,主要用于分类,是最简单的机器学习算法之一顾名思义,其算法主体思想就是根据距离相近的邻居类别,来判定自己的所属类别。
KNN算法思路
1、计算测试对象与训练集中所有对象的距离,一般采用欧式距离。
2、找出与计算对象距离最近的K个对象,作为测试对象的邻居;
3、找出这K个对象中出现频率最高的类别,该类别即为测试对象的所属类别。
一个简单的例子
在二维坐标轴中,有四个点(即训练集),分别是a1(1,1),a2(1,2),b1(3,3),b2(3,4)。其中a1,a2为A类,b1,b2为B类。
现在,有一个新的点c(2,1)(即测试对象),我们想要判断这个点属于A类还是B类。
此时我们可以采用KNN算法进行求解
1、计算测试对象与训练集中所有对象的距离,即点c与a1,a2,b1,b2的距离
2、将计算出来的距离进行升序排序
序号 | 点标签 | 点类别 | 与c点距离 |
---|---|---|---|
1 | a1 | A | 1.0 |
2 | a2 | A | 1.4 |
3 | b1 | B | 2.2 |
4 | b2 | B | 3.1 |
3、找出与计算对象距离最近的K个对象
一般情况下K的值取3(即取K=3)
序号 | 点标签 | 点类别 | 与c点距离 |
---|---|---|---|
1 | a1 | A | 1.0 |
2 | a2 | A | 1.4 |
3 | b1 | B | 2.2 |
4、找出这K个对象中出现频率最高的类别,测试对象便属于该类别
由上表可知,在所取的K对象中,f(A)=2/3,f(B)=1/3,频率最高的为A类,即点c属于A类。
手写数字识别系统
下面我们将构造使用K邻近分类器(即KNN算法)搭建的手写数字识别系统,并用测试集测试分类正确率
1、准备数据
首先下载手写数字数据集,数据集分为训练集(trainingDigits)和测试集(testDigits),训练集中有100个样本,测试集中有50个样本,每个样本的格式如下图所示
2、载入数据
为了便于计算图像与图像间的欧式距离,我们需要将图像格式化为一个向量。由于图像为32x32的二进制图像,所以我们需要将之转化为1x1024的向量。为了便于操作,我们首先编写Load_Data函数,用于读取指定目录下的所有样本,并将对每个样本进行向量化操作。
def Load_Data(file_dir):
data = [] # 数据列表
label = [] # 标签列表
file_list = os.listdir(file_dir) # 当前目录下所有样本列表
for name in file_list: # 读取当前目录下所有文件并转化成行向量
vector = np.zeros((1, 1024))
file_name = file_dir + "/" + name # 样本路径
f = open(file_name)
for i in range(32): # 转化成行向量
line = f.readline()
for j in range(32):
vector[0, i*32 + j] = int(line[j])
data.append(vector) # 数据加入数据集
label.append([int(name.split('_')[0])]) # 标签加入标签集合
return np.array(data), np.array(label)
加载的训练数据结果如下
>>> train_data, train_label = Load_Data(trian_file_dir)
>>> train_data
array([[[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.]],
...,
[[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.]]])
>>> train_label.T
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]])
3、KNN算法的实现
由上面提到的KNN算法思路
①计算距离
②找出K个距离最近的对象
③取K个对象中类别出现频率最大一类
可写出以下分类代码
def KNN(inX, data, label, k):
"""
:param inX: 测试对象
:param data: 训练数据
:param label: 训练数据的标签
:param k: 取前k个最小值
:return:
"""
# 计算输入向量与其他样本之间的距离
squre = (data - inX)**2
SumFunction = lambda x: x.sum()
distance_squre = np.apply_along_axis(SumFunction, 2, squre)
distance = np.sqrt(distance_squre)
# 计算前K个最小值频率最高的标签
sorted_index = np.argsort(distance.T)
class_count = {}
for i in range(k):
votelabel = label[sorted_index][0][i][0]
class_count[votelabel] = class_count.get(votelabel, 0) + 1
sorted_class_count = sorted(class_count.items(), key=lambda x: x[1], reverse=True)
return sorted_class_count[0][0]
4、进行手写识别系统的测试
加载训练集和测试集中的所有样本,对手写识别系统进行测试,将测试集中的每一个样本的分类结果与它的原始结果作对比,从而得到此次手写识别系统的正确率。
def HandWritingTest():
trian_file_dir = "../手写识别数据/trainingDigits"
test_file_dir = "../手写识别数据/testDigits"
train_data, train_label = Load_Data(trian_file_dir) # 训练集
test_data, test_label = Load_Data(test_file_dir) # 测试集
count = 0
for i in range(len(test_data)):
prdict = KNN(test_data[i], train_data, train_label, 3)
print("The predict label is->%s, the true label is->%s" % (prdict, test_label[i][0]))
if prdict == test_label[i][0]:
count += 1
print("The accuracy is ", count / len(test_data))
测试结果如下
The predict label is->0, the true label is->0
The predict label is->0, the true label is->0
The predict label is->0, the true label is->0
The predict label is->0, the true label is->0
The predict label is->0, the true label is->0
The predict label is->1, the true label is->1
The predict label is->1, the true label is->1
The predict label is->1, the true label is->1
The predict label is->1, the true label is->1
The predict label is->1, the true label is->1
The predict label is->2, the true label is->2
The predict label is->2, the true label is->2
The predict label is->2, the true label is->2
The predict label is->2, the true label is->2
The predict label is->2, the true label is->2
The predict label is->3, the true label is->3
The predict label is->3, the true label is->3
The predict label is->3, the true label is->3
The predict label is->3, the true label is->3
The predict label is->3, the true label is->3
The predict label is->4, the true label is->4
The predict label is->4, the true label is->4
The predict label is->4, the true label is->4
The predict label is->4, the true label is->4
The predict label is->4, the true label is->4
The predict label is->5, the true label is->5
The predict label is->5, the true label is->5
The predict label is->5, the true label is->5
The predict label is->5, the true label is->5
The predict label is->5, the true label is->5
The predict label is->6, the true label is->6
The predict label is->6, the true label is->6
The predict label is->6, the true label is->6
The predict label is->6, the true label is->6
The predict label is->6, the true label is->6
The predict label is->7, the true label is->7
The predict label is->7, the true label is->7
The predict label is->7, the true label is->7
The predict label is->7, the true label is->7
The predict label is->7, the true label is->7
The predict label is->8, the true label is->8
The predict label is->8, the true label is->8
The predict label is->8, the true label is->8
The predict label is->8, the true label is->8
The predict label is->8, the true label is->8
The predict label is->9, the true label is->9
The predict label is->9, the true label is->9
The predict label is->9, the true label is->9
The predict label is->9, the true label is->9
The predict label is->9, the true label is->9
The accuracy is 1.0
本次结果的正确率为100%,可以说是相当不错了
以上及为KNN算法的全部过程及代码