Python机器学习--K-近邻之手写数字识别(mnist数据集)


项目内容:

用K-近邻算法,对Mnist数据集完成手写数字的识别



主要内容:

  • 读取Mnist数据集
    • 获取图片数据的函数
    • 读取标签数据的函数
  • 分类函数classify0
  • 测试代码
  • 第三方库:numpy、matplotlib、operator、struct(后两个主要用于读取Mnist数据集用)

有关struct库使用的方法请自行百度了解。



代码:

MNist数据集的数据结构:

 代码:

from numpy import *
import struct
import matplotlib.pyplot as plt
import matplotlib
import operator
matplotlib.use('TKAgg')

# 读取mnist数据集的函数
# 获取图片数据的函数
def getDataImages(filename):
    with open(filename,'rb') as fn: #以二进制打开文件
        bin_data = fn.read() # 获取文件内容
    offset = 0
    fmt_header = '>iiii'     # 4个int大小的内容,
    # 获取前4个int大小的数据,包含魔数,图片数量,图片的大小
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))
    image_size = num_rows * num_cols # 每一张数字图片的大小
    offset += struct.calcsize(fmt_header)  # 获得数据在缓存中的指针位置,从前面介绍的数据结构可以看出,读取了前4行之后,指针位置(即偏移位置offset)指向0016。
    fmt_image = '>' + str(image_size) + 'B'  # 图像数据像素值的类型为unsigned char型,对应的format格式为B。这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)
    # print(fmt_image, offset, struct.calcsize(fmt_image))
    images = empty((num_images, image_size))
    for i in range(num_images):
        images[i,:] = array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((image_size))
        offset += struct.calcsize(fmt_image)
    return images

# 获取标签数据的函数
def getDataLabels(filename):
    with open(filename,'rb') as fn:
        bin_data = fn.read()
    offset = 0
    fmt_header = '>ii'
    magic_number, num_imageslabel = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片标签数量: %d张' % (magic_number, num_imageslabel))
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = empty(num_imageslabel)
    for i in range(num_imageslabel):
        # if (i + 1) % 10000 == 0:
        #     print('已解析 %d' % (i + 1) + '张')
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels

# 分类函数
def classify0(inX,dataSet,label,kNum):
    '''
    :param inX: 用于分类的输入向量
    :param dataSet: 数据集
    :param label: 数据集对应的标签
    :param kNum: k的数量
    :return:
    '''
    dataSetSize = dataSet.shape[0] #计算data行数,即数据集的数量

    diffMat = tile(inX,(dataSetSize,1))-dataSet#将inX扩展到跟dataSet同样大小,再相减
    sqDiffMat = diffMat ** 2#diffMat的每个值平方
    sqDistances = sqDiffMat.sum(axis=1) #每一行相加,计算行和
    distances = sqDistances ** 0.5#计算距离

    sortedDistIndicies = distances.argsort()#对距离进行从小到大排序,获得index
    classCount = {}
    for i in range(kNum):
        voteIlabel = label[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)#https://blog.csdn.net/weixin_30315905/article/details/96987587
    return sortedClassCount[0][0]

# 测试函数 #记得把文件路径改一下哦
if __name__ == '__main__':
    train_images_file = 'D:\迅雷下载\新建文件夹\\train-images.idx3-ubyte'
    trainImageData = getDataImages(train_images_file)
    train_labels_file = 'D:\迅雷下载\新建文件夹\\train-labels.idx1-ubyte'
    trainImageLabels = getDataLabels(train_labels_file)

    test_images_file = 'D:\迅雷下载\新建文件夹\\t10k-images.idx3-ubyte'
    test_labels_file = 'D:\迅雷下载\新建文件夹\\t10k-labels.idx1-ubyte'
    testImageData = getDataImages(test_images_file)
    testImageLabels = getDataLabels(test_labels_file)

    errorCount = 0
    for i in range(len(testImageLabels)):
        resultClassify = classify0(testImageData[i, :], trainImageData, trainImageLabels, 8)
        if resultClassify != testImageLabels[i]:
            errorCount += 1
            print("**********************++++++++++++++++++++++*******************************")
        print("测试集第%d,真实数字为:%s;KNN分类得到的结果为:%s"%(i,testImageLabels[i],resultClassify))
        if (i + 1) % 100 == 0:
            print("+++++++++++++++++++++++++++++在测试集中,共 %d 个测试的错误率为:%f" % (i,errorCount / i))
    print("在测试集中错误率为:%f" % (errorCount / testImageData.shape[0]))


    ##下面代码用于展示图片数据,x表示数据集中的第几张照片,替换为一个数字即可
    # p = trainImageData[x,:].reshape(28,28)
    # plt.imshow(p, 'gray')
    # plt.show()
    # print(trainImageLabels[1])



运行结果:

  • 0
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值