手工实现KNN算法预测mnist数据集

KNN(K-Nearest Neighbor)算法的优劣
KNN算法是机器学习最基础的算法,它的基本原理就是找到训练数据集里面离需要预测的样本点“距离最近”的k个对象,取其中出现最多的标签作为预测值。
其他更先进的机器学习算法是在训练集上花大量时间训练出一个模型,预测时只要用这个模型直接快速预测,而无需再去处理训练集。而KNN算法恰好相反,其没有训练过程,但在预测过程中要遍历训练集,因而预测花费较大。

优点:
(1) 核心思想简单易懂
(2) 容易实现
(3) 没有训练过程

缺点:
(1) 预测过程时间空间复杂度过高,造成无法高效预测,处理较大的样本量需要的时间非常多
(2) 预测准确度并不高,只能在较简单的图片上work,较复杂的图片则不行
(3) 对训练数据的容错率较低。如果训练集中,出现错误的几个数据都在需要分类的数值的旁边,这样就会导致预测偏差
(4) 当训练集不平衡时,预测结果可能会偏向训练集中数量较多的类别

前置条件:
(1)安装Python的numpy模块和图像处理库Pillow
(2)将mnist数据集的4个文件下载并解压到py文件同一目录下
(3)如需要测试自己的手写数字图片,将图片保存为png格式,命名为test,并放在py文件同一目录下

K值和距离类型是KNN算法的两个超参:
(1)K值太小会出现过拟合,此时建立的模型复杂度高,决策边界崎岖;
(2)K值太大会出现欠拟合,此时建立的模型简单,决策边界平滑。
(3)距离类型的选取需要自行调参,确定最优情况。

这里只取了测试集中前100个样本进行预测,根据正确率的变化来调参(K值不能太小和太大,又要满足较高的预测准确率)。

测试情况如下:
K=1:
L1 98%
L2 100%

K=2:
L1 97%
L2 97%

K=3:
L1 98%
L2 99%

K=4:
L1 96%
L2 99%

K=5:
L1 98%
L2 99%

K=6:
L1 97%
L2 99%

代码:

import numpy as np
import struct
from PIL import Image

def convertImage(path): # 返回一个ndarray,每一行是一个图片,同时返回图片尺寸
    buff = open(path, "rb").read() # 获得文件对象句柄后全部读进来
    fmt = ">iiii" # 大端模式,读四个32位带符号整数
    magicNumber, number, rows, cols = struct.unpack_from(fmt, buff, 0) # struct.unpack_from(fmt=,buffer=,offset=) 返回一个tuple
    retList = []
    cur = struct.calcsize(fmt) # 计算格式占的大小
    fmt = ">" + "B" * rows * cols # 大端模式,读28*28个1位无符号整数
    for i in range(number):
        retList.append(struct.unpack_from(fmt, buff, cur)) # ndarray新增元素
        cur += struct.calcsize(fmt)
    return np.array(retList), rows, cols # 返回一个tuple

def convertLabel(path):
    buff = open(path, "rb").read() # rb读的方式打开二进制文件
    fmt = ">ii"
    magicNumber, number = struct.unpack_from(fmt, buff, 0)
    retList = []
    cur = struct.calcsize(fmt)
    fmt = ">" + "B"
    for i in range(number):
        retList.append(struct.unpack_from(fmt, buff, cur)[0])
        cur += struct.calcsize(fmt)
    return retList

def convertSingleImage(path, rows, cols):
    img = Image.open(path)
    pix = img.load()
    retArr = np.empty(rows * cols)
    for i in range(rows):
        for j in range(cols):
            R, G, B, A= pix[j, i] # RGBA,Python中A在最后面
            retArr[i * cols + j] = 255 - (R + G + B) / 3
    return retArr

def KNN(testImg):
    if disType==1:
        disArr = np.sum(np.absolute(traImgArr-testImg), axis = 1) # traImgArr是二维60000*784的,testImg是一维784的,直接相减即可让traImgArr的每一行都与testImg相减
    if disType==2:
        disArr = np.sum(np.square(traImgArr-testImg), axis = 1) # traImgArr是二维60000*784的,testImg是一维784的,直接相减即可让traImgArr的每一行都与testImg相减
    nei = np.empty(K)
    nei.dtype  = "int64" # 把nei的数据类型强行变成int64,防止往下第三行的地方数据转化时出错
    for i in range(K):
        pos = np.argmin(disArr) # 距离最小值的索引
        nei[i] = traLblList[pos]
        disArr[pos] = MAX
    arr = np.bincount(nei) # 统计出现次数
    return np.argmax(arr) # 返回出现次数最多的

def testData(expSize): # 预测测试集并查看正确率
    corrCnt = 0 # 预测正确的个数
    testSize = min(np.size(testImgArr, 0), expSize) # 不超过测试集总个数
    
    for i in range(testSize):
        testNum = KNN(testImgArr[i])
        #print("No.", i, "识别值:", testNum, "真实值:", testLblList[i])
        if testNum == testLblList[i]:
            corrCnt += 1
    print("测试集正确率:", corrCnt / testSize * 100,"%")

def testLocalImg(): # 预测本地的图片
    img = convertSingleImage(testPath, rows, cols)
    print("本地图片预测结果:", KNN(img))


#输入超参
K = int(input("请输入K值:")) # KNN的K
disType=int(input("请输入距离类型(1为曼哈顿距离,2为欧几里得距离):")) # 距离的类型

MAX = 10**9 # 常量,用来求K最近邻
expSize=100 # 要测的测试集的数量

# 数据集的预处理
traImgPath = "train-images.idx3-ubyte"
traLblPath = "train-labels.idx1-ubyte"
testImgPath = "t10k-images.idx3-ubyte"
testLblPath = "t10k-labels.idx1-ubyte"
testPath = "test.png"

traImgArr, rows, cols = convertImage(traImgPath)
traLblList = convertLabel(traLblPath)
testImgArr, rows, cols = convertImage(testImgPath)
testLblList = convertLabel(testLblPath)

#预测本地图片以及测试集图片
testLocalImg()
testData(expSize) 
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值