KNN算法识别手写数字

前言:从现在开始博主要开始学习机器学习了,欢迎有共同兴趣的人一起学习。废话不多说了,开始上代码

数据集我用的Kaggle上的数据,下载地址:https://www.kaggle.com/gb00000/keras-cnn-digit/data

由于这是第一篇bolg,所以我没有自己手写KNN算法,用的是 sklearn 库,具体的手写KNN算法我会在下一篇bolg上写。

一、将数据集csv文件导入到python中,并将csv格式转为list格式

def dataset():
    #读取训练集文件路径
    path =  '/trains1.csv'  #(os.path.abspath("../MachineLearningData/kNN/use Python and NumPy/trainingDigits"))
    with open(path,'r',encoding="utf-8") as f:
        reader = csv.reader(f)
        #[row for row in reader]叫做列表解析 根据已有列表,高效创建新列表的方式。  列表解析是Python迭代机制的一种应用,它常用于实现创建新的列表,因此用在[]中。
        #[row for row in reader] 等价于
        #rows =[]
        #for row in reader:
        #     rows.append(row)
        rows = [row for row in reader]
    #因为第一行是说明,从第二行开始才是数据,所以从第二行开始
    return rows[1:]

二、将返回的list进一步划分成:训练数据集,训练目标数据,测试数据集,测试目标数据

这里我将list进行分割,70%为训练集,30%为测试集

    data = dataset()
    #训练集
    data_train = data[len(data)//3+1:]
    trainingdata = [i[1:] for i in data_train]
    #第0列为目标数据集,获取第0列的数据
    trainingtargetdata = [i[0] for i in data_train]
    #测试集
    data_test = data[:len(data) // 3]
    testdata = [i[1:] for i in data_test]
    testtargetdata = [i[0] for i in data_test]

三、调用 KNeighborsClassifier 方法进行数据预测

#开始训练,这里我设置的K临近值为3
    knn = KNeighborsClassifier(n_neighbors=3)
    # 训练数据
    knn.fit(trainingdata, trainingtargetdata)
    prediction = knn.predict(testdata)
    knn.score(testdata,testtargetdata)
    right=0
    for test,result in zip(testtargetdata,prediction):
        if test == result:
            right += 1
        print('测试数据为:{},预测数据为:{}'.format(test,result))
    print('测试的准确率为->{}%'.format((right/len(testtargetdata))*100))

测试结果:kaggle上一共有好几万条数据,这里我只用了1000条,所以准确率不是那么高。

重新训练了3万条数据的结果,跑了20分钟,准确度比之前高了很多

完整代码:

import os
import numpy as np
import csv

from sklearn import neighbors
from sklearn.neighbors import KNeighborsClassifier
#define one read train.csv  function
def dataset():
    #读取训练集文件路径
    path =  '数据集路径'           #(os.path.abspath("../MachineLearningData/kNN/use Python and NumPy/trainingDigits"))
    with open(path,'r',encoding="utf-8") as f:
        reader = csv.reader(f)

        #列表解析  根据已有列表,高效创建新列表的方式。  列表解析是Python迭代机制的一种应用,它常用于实现创建新的列表,因此用在[]中。
        #[row for row in reader] 等价于
        #rows =[]
        #for row in reader:
        #     rows.append(row)
        rows = [row for row in reader]
    return rows[1:]

def main():
    data = dataset()
    #训练集
    data_train = data[len(data)//3+1:]
    trainingdata = [i[1:] for i in data_train]
    trainingtargetdata = [i[0] for i in data_train]
    #测试集
    data_test = data[:len(data) // 3]
    testdata = [i[1:] for i in data_test]
    testtargetdata = [i[0] for i in data_test]

    #开始训练
    knn = KNeighborsClassifier(n_neighbors=3)
    # 训练数据
    knn.fit(trainingdata, trainingtargetdata)
    prediction = knn.predict(testdata)
    knn.score(testdata,testtargetdata)
    right=0
    for test,result in zip(testtargetdata,prediction):
        if test == result:
            right += 1
        print('测试数据为:{},预测数据为:{}'.format(test,result))
    print('测试的准确率为->{}%'.format((right/len(testtargetdata))*100))
if __name__ == '__main__':
    main()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值