java使用knn实现mnist_KNN 实现mnist数据集分类

本文介绍了使用Java实现KNN(K-Nearest Neighbors)算法对MNIST手写数字数据集进行分类的过程。首先,对训练数据集和验证数据集进行预处理,包括删除说明行、数据标准化。接着详细展示了KNN算法的两种实现方式,通过计算欧几里得距离找到最近的k个邻居,并根据类别最多的类别作为预测结果。最后,通过计算召回率评估模型性能,并补充了Python和Numpy的相关语法细节。
摘要由CSDN通过智能技术生成

一 数据预处理

训练数据集和验证数据集分别为train.csv和test.csv。数据集下载地址:http://pan.baidu.com/s/1eQyIvZG

要分别对训练数据集和验证数据集进行分析,分析其内部数据的特征,下面分别对两个数据集进行处理:

1.1 训练数据集处理

train.csv 里面结构为42001 * 785。其中第一行为文字说明,应该去掉,其余每一行均表示一个图像,大小为28*28,共784个像素值;第一列为类标签,每一个标签表示一个图像所代表的数字,范围为0-9;所以处理的步骤为:把所有数据存入列表中;删除第一行,得到42000*785;分离开第一列和剩余数据,分别得到42000*1和42000*784两个矩阵。

具体代码如下:

def loadtraindata(trainfile):#传参为所读文件名

l = list()#创建序列,要保存文件内容

with open(trainfile,'rb') as filename:

lines=csv.reader(filename)for line inlines:

l.append(line)del l[0]#删除第一行

l = np.array(l)#转换为数组

label = l[:,0]#取数组内所有行第一列元素

data = l[:,1:]#取数组内所有行,从第二列至最后列元素

label = np.int32(label)#int32为numpy 内部函数,进行数据类型转换

data = nomalizing(np.int32(data))#nomalizing 为自定义函数,进行数据标准化

returndata,label

标准化函数代码如下:

defnomalizing(array):

m,n= np.shape(array)#shape函数为得到数组的各个维度

for i inxrange(m):for j inxrange(n):if array[i,j] !=0:

array[i,j]= 1

return array

二 KNN实现分类

现已知训练集中有42000组元素和对应每组的类别,现给出一个未知类别的一组元素,要求预测其类别。KNN的做法是:找到与该组元素最近的k组;找到这k组元素里类别相同数最多的一个类别;认为该类别就是该未知类别元素的类别;

2.1 KNN具体代码如下:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值