Comparing k-Nearest-Neighbour Implementations for MNIST Image Classification

Reading the dataset

    Because the site hosting the data is unreliable, I downloaded the dataset to disk and read it locally.

Most examples online read the dataset into three-dimensional arrays, which is convenient for displaying the images; however, since both the distance computation and sklearn work with two-dimensional arrays, I later changed the code to return 2-D arrays of shape (n, 784).
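
As a minimal sketch (the variable names here are illustrative, not part of the original code), the two layouts differ only by a reshape: the flat (n, 784) form feeds the distance computation and sklearn directly, while the (n, 28, 28) form is what image-display functions expect.

# Sketch only: 'images' stands for the (n, 784) array returned by decode_idx3_ubyte below
flat = images                             # shape (n, 784): used for k-NN and sklearn
as_pictures = images.reshape(-1, 28, 28)  # shape (n, 28, 28): convenient for plt.imshow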

def decode_idx3_ubyte(idx3_ubyte_file):
    """
    A generic parser for idx3 (image) files.
    :param idx3_ubyte_file: path to the gzipped idx3 file
    :return: images as a 2-D array of shape (n, 784)
    """
    # Read the binary data
    bin_data = gzip.open(idx3_ubyte_file, 'rb').read()

    # Header fields, in order: magic number, image count, image height, image width
    offset = 0
    fmt_header = '>IIII'

    # Parse the image records
    offset += struct.calcsize(fmt_header)
    fmt_image = '>784B'
    image_size = 100

    # Read more samples if this is the training set
    if 'train' in idx3_ubyte_file:
        image_size = 6000
    images = np.empty((image_size, 784))
    for i in range(image_size):
        temp = struct.unpack_from(fmt_image, bin_data, offset)
        images[i] = np.reshape(temp, 784)
        offset += struct.calcsize(fmt_image)
    return images


def decode_idx1_ubyte(idx1_ubyte_file):
    """
    A generic parser for idx1 (label) files.
    :param idx1_ubyte_file: path to the gzipped idx1 file
    :return: labels as a 1-D integer array
    """
    # Read the binary data
    bin_data = gzip.open(idx1_ubyte_file, 'rb').read()

    # Header fields, in order: magic number and label count
    offset = 0
    fmt_header = '>II'

    # Parse the label records
    offset += struct.calcsize(fmt_header)
    fmt_label = '>B'
    label_size = 100

    # Read more samples if this is the training set
    if 'train' in idx1_ubyte_file:
        label_size = 6000
    labels = np.empty(label_size, dtype=int)
    for i in range(label_size):
        labels[i] = struct.unpack_from(fmt_label, bin_data, offset)[0]
        offset += struct.calcsize(fmt_label)
    return labels

The number of samples read is deliberately limited here: one tenth of the training set (6,000 images) and 100 test images.
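
As a small sketch of an alternative (not part of the original code), the sample counts could also be taken from the idx header instead of being hard-coded, since the second header field stores the number of items:

# Sketch: read the counts from the header of an idx3 file ('bin_data' as in decode_idx3_ubyte)
magic, num_images, num_rows, num_cols = struct.unpack_from('>IIII', bin_data, 0)
image_size = num_images // 10   # e.g. keep one tenth of the images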

Implementing the k-nearest-neighbour algorithm

class NearstNeighbour:
    def __init__(self, k):
        self.k = k
    
    def train(self, X, y):
        self.Xtr = X
        self.ytr = y
        return self
    
    def predict(self, test_images):
        predictions = []

        # This block is adapted from https://github.com/Youngphone/KNN-MNIST/blob/master/KNN-MNIST.ipynb
        # Iterate over the test samples
        for test_item in test_images:
            datasetsize = self.Xtr.shape[0]
            # Euclidean distance between the test sample and every training sample
            diffMat = np.tile(test_item, (datasetsize, 1)) - self.Xtr
            sqDiffMat = diffMat ** 2
            sqDistances = sqDiffMat.sum(axis = 1)
            distances = sqDistances ** 0.5
            # Sort the distances in ascending order and keep the sorted indices
            sortedDistIndicies = distances.argsort()
            # Vote counts per class
            classCount = {}
            # Consider the k nearest neighbours
            for i in range(self.k):
                # sortedDistIndicies[i] is the index of the i-th closest training sample,
                # so self.ytr[sortedDistIndicies[i]] is its label
                voteIlabel = self.ytr[sortedDistIndicies[i]]
                # Add one vote for that class
                classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
            # Sort classes by vote count, descending
            sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
            predictions.append(sortedClassCount[0][0])

        return predictions
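
The loop above tiles and subtracts the training matrix once per test image. As a rough, hypothetical sketch (the function name and details are mine, not the author's code), the same Euclidean distances can be computed for the whole test set at once with broadcasting, followed by a majority vote; tie-breaking may differ slightly from the dictionary-based vote above.

import numpy as np

def knn_predict_vectorized(Xtr, ytr, Xte, k):
    # Squared Euclidean distances: ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
    d2 = (Xte ** 2).sum(axis=1)[:, None] + (Xtr ** 2).sum(axis=1)[None, :] - 2 * Xte @ Xtr.T
    # Indices of the k smallest distances for each test sample (order within the k is arbitrary)
    nearest = np.argpartition(d2, k, axis=1)[:, :k]
    # Majority vote over the neighbours' labels (labels are the digits 0-9)
    votes = ytr[nearest].astype(int)
    return np.array([np.bincount(row).argmax() for row in votes])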

Comparison with sklearn's k-nearest neighbours

# -*- encoding: utf-8 -*-
'''
@File    :   NearstNeighbour.py
@Time    :   2021/03/27 15:40:05
@Author  :   Wihau 
@Version :   1.0
@Desc    :   None
'''

# here put the import lib
import gzip
import numpy as np
import struct
import operator
import time

from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix


train_images_idx3_ubyte_file = 'train-images-idx3-ubyte.gz'
train_labels_idx1_ubyte_file = 'train-labels-idx1-ubyte.gz'
test_images_idx3_ubyte_file = 't10k-images-idx3-ubyte.gz'
test_labels_idx1_ubyte_file = 't10k-labels-idx1-ubyte.gz'

def decode_idx3_ubyte(idx3_ubyte_file):
    """
    A generic parser for idx3 (image) files.
    :param idx3_ubyte_file: path to the gzipped idx3 file
    :return: images as a 2-D array of shape (n, 784)
    """
    # Read the binary data
    bin_data = gzip.open(idx3_ubyte_file, 'rb').read()

    # Header fields, in order: magic number, image count, image height, image width
    offset = 0
    fmt_header = '>IIII'

    # Parse the image records
    offset += struct.calcsize(fmt_header)
    fmt_image = '>784B'
    image_size = 100

    # Read more samples if this is the training set
    if 'train' in idx3_ubyte_file:
        image_size = 6000
    images = np.empty((image_size, 784))
    for i in range(image_size):
        temp = struct.unpack_from(fmt_image, bin_data, offset)
        images[i] = np.reshape(temp, 784)
        offset += struct.calcsize(fmt_image)
    return images


def decode_idx1_ubyte(idx1_ubyte_file):
    """
    A generic parser for idx1 (label) files.
    :param idx1_ubyte_file: path to the gzipped idx1 file
    :return: labels as a 1-D integer array
    """
    # Read the binary data
    bin_data = gzip.open(idx1_ubyte_file, 'rb').read()

    # Header fields, in order: magic number and label count
    offset = 0
    fmt_header = '>II'

    # Parse the label records
    offset += struct.calcsize(fmt_header)
    fmt_label = '>B'
    label_size = 100

    # Read more samples if this is the training set
    if 'train' in idx1_ubyte_file:
        label_size = 6000
    labels = np.empty(label_size, dtype=int)
    for i in range(label_size):
        labels[i] = struct.unpack_from(fmt_label, bin_data, offset)[0]
        offset += struct.calcsize(fmt_label)
    return labels

def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file)

def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file)

def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file)

def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file)

class NearstNeighbour:
    def __init__(self, k):
        self.k = k
    
    def train(self, X, y):
        self.Xtr = X
        self.ytr = y
        return self
    
    def predict(self, test_images):
        predictions = []

        # Iterate over the test samples
        for test_item in test_images:
            datasetsize = self.Xtr.shape[0]
            # Euclidean distance between the test sample and every training sample
            diffMat = np.tile(test_item, (datasetsize, 1)) - self.Xtr
            sqDiffMat = diffMat ** 2
            sqDistances = sqDiffMat.sum(axis = 1)
            distances = sqDistances ** 0.5
            # Sort the distances in ascending order and keep the sorted indices
            sortedDistIndicies = distances.argsort()
            # Vote counts per class
            classCount = {}
            # Consider the k nearest neighbours
            for i in range(self.k):
                # sortedDistIndicies[i] is the index of the i-th closest training sample,
                # so self.ytr[sortedDistIndicies[i]] is its label
                voteIlabel = self.ytr[sortedDistIndicies[i]]
                # Add one vote for that class
                classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
            # Sort classes by vote count, descending
            sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
            predictions.append(sortedClassCount[0][0])

        return predictions

train_images = load_train_images()
train_labels = load_train_labels()
test_images = load_test_images()
test_labels = load_test_labels()



k = 5
# Prediction with the personal k-NN implementation
print("-----Personal k nearest neighbour-----")
# Prediction time
start = time.time()
knn = NearstNeighbour(k)
predictions = knn.train(train_images, train_labels).predict(test_images)
end = time.time()
print("time of prediction:%.3f s" % (end-start))
# Accuracy
accuracy = accuracy_score(test_labels, predictions)
print("accuracy score:", accuracy)
# Confusion matrix
matrix = confusion_matrix(test_labels, predictions)
print(matrix)

# Prediction with sklearn's k-NN
print("-----Sklearn nearest neighbour-----")
# Prediction time
start = time.time()
sknn = KNeighborsClassifier(n_neighbors = k)
skpredictions = sknn.fit(train_images, train_labels).predict(test_images)
end = time.time()
print("time of prediction:%.3f s" % (end-start))
# Accuracy
skaccuracy = accuracy_score(test_labels, skpredictions)
print("accuracy score:", skaccuracy)
# Confusion matrix
skmatrix = confusion_matrix(test_labels, skpredictions)
print(skmatrix)

The results are as follows.

k = 5: (screenshot of prediction time, accuracy, and confusion matrix for both implementations)
k = 10: (screenshot of prediction time, accuracy, and confusion matrix for both implementations)
