KNN算法又称K邻近算法(K Nearest Neighbor),其基本思想为:样本空间中,某样本的类别为距离其最近的k个邻居所属类别中最多的那个类别。
MNIST数据集为一个带标注的手写数字识别的数据集,其官方下载地址为http://yann.lecun.com/exdb/mnist/。数据集包含60000个训练集和10000个测试集。该数据集中的文件以二进制方式保存,因此在读入时需要以二进制方式打开。
在学习读入MNIST数据集过程中,最大的收获是对Python中struct包的简单的了解。其中,对MNIST数据集的读入,主要参考博客http://www.cnblogs.com/x1957/archive/2012/06/02/2531503.html。个人觉得这一篇讲的还是很清楚的。
至于算法实现,有两个收获:
1. 算法优化:初始算法步骤比较冗长,肆无忌惮的使用for循环。这种做法对于MNIST这种数据量相对较大的数据集还是很致命的。改进后的算法,其运行时间明显减少(虽然还是很大)。
2. Numpy库中的argsort(list):以前对列表进行排序后想要获取对应的列表的索引,各种index,而且容易出现值重合的现象。但是,argsort(list)真的是神器。它对列表的值进行排序(降序)后,返回的列表值对应的索引。有木有很方便!
下面就是代码啦(Python):
读入图片,转化为一个数组(每一行为对应图片的像素):
# read the image file
# input: file path
#output: the list of piexl array for each image
def read_image(file_path):
f_open=open(file_path,"rb")
content=f_open.read()
index=0
magic, num_images,num_rows,num_columns=struct.unpack_from(">IIII",content,index) # 以大端法读入四个unsigned int
print("number of images:"+str(num_images))
print("number of rows:"+str(num_rows))
print("number of columns:"+str(num_columns))
index+=struct.calcsize(">IIII")
img_piexl=[]
for i in range(num_images):
piexl_all=[]
for j in range(num_columns):
for k in range(num_rows):
piexl=struct.unpack_from(">B",content,in