公众号后台回复“图书“,了解更多号主新书内容
作者:小一
来源:小一的学习笔记
写在前面的话
大家好,我是小一
前面两节介绍了 k近邻算法,想必你对于k近邻算法的理论应该没啥问题
这节是k近邻算法的实战,小一感觉还比较好玩,刚好最近天池上有一个街景数字识别的比赛,改一改也可以去试试打比赛
ok,一起来看如何使用knn进行数字识别
开始实战
先来介绍一下项目的的背景
已有部分手写数字数据集,数据是存储在txt中,以一个32*32的矩阵存储。在矩阵中,有像素点的位置标为1,没有的标为0。每个txt文件的名称第一位数字表示该文件代表的书写数字
目前的训练数据1934个,测试数据946个,训练数据和测试数据并无较大区别
针对上面的背景,再结合上节的k近邻算法,建模的大致流程如下:
每个文件的32*32矩阵通过拼接存成一列,存入新的矩阵A,这样每一列表示一个图片
用测试数据和矩阵A的每一列求距离,求得的距离放在距离数组中
从距离数组中取出最小的k个距离对应的训练图片的标签
选取标签中的众数当做该预测图片的预测结果
方法也比较简单,顺着k近邻的流程撸一遍就大致明白了。
每个文件都是一张图,所有有很多个txt文件
截几个数据的图,在配上小一鬼斧神工的画技
读取数据
因为数据是小的txt文件,所以需要进行相应的处理
这里直接贴代码了,没什么技术难度
# 遍历文件
filenames = os.listdir(dirpath)
# 数据保存在二维数组中,标签保存在一维数组中
data_arr = np.zeros((len(filenames), 1024))
data_label = []
for i in range(len(filenames)):
filename = filenames[i]
# 读取每个文件的内容
filepath = os.path.join(dirpath, filename)
# 将32*32拼接成1*1024数组
data_arr[i, :] = concat_info(filepath)
data_label.append(filename[:1])
处理之后的数据现在长这样:
通过把二维数组拼接成一维数组,所以每一列都表示一个图片
训练数据是一个二维的矩阵,样本标签是一个一维的列表
开始建模
因为数据都是规整数据,所以省略EDA步骤
这一步需要选择最优的k值,因为不知道哪个k值最优,前面说过可以使用交叉验证确定最优
这里我们直接使用网格搜索让模型自己找到最优的k值
ok,网格搜索确定最优的k值为3,那我们直接建立模型
然后对测试集中的图片进行预测
准确率达到98%,但是还是有部分图片是预测错误的,把出错误的图片拎出来康康,到底是写成什么鬼样子了才会连knn都不认识
就这几个图片,看一下究竟是啥样子
一个原因是我们的样本数据不是很多,造成了识别错误,另一个也确实是手写的数字太......
考虑一下怎么解决这个问题:
增大样本量,增大模型的识别能力,但是也会相应的增大计算成本
设置权重,1的权重比0的权重高,所以01交接处的数字权重会高于连续数字的权重
好了,今天的项目就到这了
◆ ◆ ◆ ◆ ◆
麟哥新书已经在京东上架了,我写了本书:《拿下Offer-数据分析师求职面试指南》,目前京东正在举行双12活动,大家可以用相原价5折的价格购买,还是非常划算的:
数据森麟公众号的交流群已经建立,许多小伙伴已经加入其中,感谢大家的支持。大家可以在群里交流关于数据分析&数据挖掘的相关内容,还没有加入的小伙伴可以扫描下方管理员二维码,进群前一定要关注公众号奥,关注后让管理员帮忙拉进群,期待大家的加入。
管理员二维码:
猜你喜欢
● 你相信逛B站也能学编程吗