机器学习数据集-MNIST

转载 2016年05月31日 09:30:31

本文转载自:http://www.csuldw.com/2016/02/25/2016-02-25-machine-learning-MNIST-dataset/

介绍

在学习机器学习的时候,首当其冲的就是准备一份通用的数据集,方便与其他的算法进行比较。在这里,我写了一个用于加载MNIST数据集的方法,并将其进行封装,主要用于将MNIST数据集转换成numpy.array()格式的训练数据。直接下面看下面的代码吧(主要还是如何用python去读取binnary file)!

MNIST数据集原网址:http://yann.lecun.com/exdb/mnist/

Github源码下载:数据集(源文件+解压文件+字体图像jpg格式), py源码文件

文件目录

MNIST数据集解释

将MNIST文件解压后,发现这些文件并不是标准的图像格式。这些图像数据都保存在二进制文件中。每个样本图像的宽高为28*28。

mnist的结构如下,选取train-images

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel

首先该数据是以二进制存储的,我们读取的时候要以’rb’方式读取;其次,真正的数据只有[value]这一项,其他的[type]等只是来描述的,并不真正在数据文件里面。也就是说,在读取真实数据之前,我们要读取4个32 bit integer.由[offset]我们可以看出真正的pixel是从0016开始的,一个int 32位,所以在读取pixel之前我们要读取4个 32 bit integer,也就是magic number, number of images, number of rows, number of columns. 当然,在这里使用struct.unpack_from()会比较方便.

源码

说明:

  • ‘>IIII’指的是使用大端法读取4个unsinged int 32 bit integer
  • ‘>784B’指的是使用大端法读取784个unsigned byte

data_util.py文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 25 14:40:06 2016
load MNIST dataset
@author: liudiwei
"""
import numpy as np 
import struct
import matplotlib.pyplot as plt 
import os

class DataUtils(object):
    """MNIST数据集加载
    输出格式为:numpy.array()    
    
    使用方法如下
    from data_util import DataUtils
    def main():
        trainfile_X = '../dataset/MNIST/train-images.idx3-ubyte'
        trainfile_y = '../dataset/MNIST/train-labels.idx1-ubyte'
        testfile_X = '../dataset/MNIST/t10k-images.idx3-ubyte'
        testfile_y = '../dataset/MNIST/t10k-labels.idx1-ubyte'
        
        train_X = DataUtils(filename=trainfile_X).getImage()
        train_y = DataUtils(filename=trainfile_y).getLabel()
        test_X = DataUtils(testfile_X).getImage()
        test_y = DataUtils(testfile_y).getLabel()
        
        #以下内容是将图像保存到本地文件中
        #path_trainset = "../dataset/MNIST/imgs_train"
        #path_testset = "../dataset/MNIST/imgs_test"
        #if not os.path.exists(path_trainset):
        #    os.mkdir(path_trainset)
        #if not os.path.exists(path_testset):
        #    os.mkdir(path_testset)
        #DataUtils(outpath=path_trainset).outImg(train_X, train_y)
        #DataUtils(outpath=path_testset).outImg(test_X, test_y)
    
        return train_X, train_y, test_X, test_y 
    """


    def __init__(self, filename=None, outpath=None):
        self._filename = filename
        self._outpath = outpath
        
        self._tag = '>'
        self._twoBytes = 'II'
        self._fourBytes = 'IIII'    
        self._pictureBytes = '784B'
        self._labelByte = '1B'
        self._twoBytes2 = self._tag + self._twoBytes
        self._fourBytes2 = self._tag + self._fourBytes
        self._pictureBytes2 = self._tag + self._pictureBytes
        self._labelByte2 = self._tag + self._labelByte
    
    def getImage(self):
        """
        将MNIST的二进制文件转换成像素特征数据
        """
        binfile = open(self._filename, 'rb') #以二进制方式打开文件
        buf = binfile.read() 
        binfile.close()
        index = 0
        numMagic,numImgs,numRows,numCols=struct.unpack_from(self._fourBytes2,\
                                                                    buf,\
                                                                    index)
        index += struct.calcsize(self._fourBytes)
        images = []
        for i in range(numImgs):
            imgVal = struct.unpack_from(self._pictureBytes2, buf, index)
            index += struct.calcsize(self._pictureBytes2)
            imgVal = list(imgVal)
            for j in range(len(imgVal)):
                if imgVal[j] > 1:
                    imgVal[j] = 1
            images.append(imgVal)
        return np.array(images)
        
    def getLabel(self):
        """
        将MNIST中label二进制文件转换成对应的label数字特征
        """
        binFile = open(self._filename,'rb')
        buf = binFile.read()
        binFile.close()
        index = 0
        magic, numItems= struct.unpack_from(self._twoBytes2, buf,index)
        index += struct.calcsize(self._twoBytes2)
        labels = [];
        for x in range(numItems):
            im = struct.unpack_from(self._labelByte2,buf,index)
            index += struct.calcsize(self._labelByte2)
            labels.append(im[0])
        return np.array(labels)

    def outImg(self, arrX, arrY):
        """
        根据生成的特征和数字标号,输出png的图像
        """
        m, n = np.shape(arrX)
        #每张图是28*28=784Byte
        for i in range(1):
            img = np.array(arrX[i])
            img = img.reshape(28,28)
            outfile = str(i) + "_" +  str(arrY[i]) + ".png"
            plt.figure()
            plt.imshow(img, cmap = 'binary') #将图像黑白显示
            plt.savefig(self._outpath + "/" + outfile)

test.py文件:简单地测试了一下KNN算法,代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 25 16:09:58 2016
Test MNIST dataset 
@author: liudiwei
"""

from sklearn import neighbors  
from data_util import DataUtils
import datetime  


def main():
    trainfile_X = '../dataset/MNIST/train-images.idx3-ubyte'
    trainfile_y = '../dataset/MNIST/train-labels.idx1-ubyte'
    testfile_X = '../dataset/MNIST/t10k-images.idx3-ubyte'
    testfile_y = '../dataset/MNIST/t10k-labels.idx1-ubyte'
    train_X = DataUtils(filename=trainfile_X).getImage()
    train_y = DataUtils(filename=trainfile_y).getLabel()
    test_X = DataUtils(testfile_X).getImage()
    test_y = DataUtils(testfile_y).getLabel()

    return train_X, train_y, test_X, test_y 


def testKNN():
    train_X, train_y, test_X, test_y = main()
    startTime = datetime.datetime.now()
    knn = neighbors.KNeighborsClassifier(n_neighbors=3)  
    knn.fit(train_X, train_y)  
    match = 0;  
    for i in xrange(len(test_y)):  
        predictLabel = knn.predict(test_X[i])[0]  
        if(predictLabel==test_y[i]):  
            match += 1  
      
    endTime = datetime.datetime.now()  
    print 'use time: '+str(endTime-startTime)  
    print 'error rate: '+ str(1-(match*1.0/len(test_y)))  

if __name__ == "__main__":
    testKNN()

通过main方法,最后直接返回numpy.array()格式的数据:train_X, train_y, test_X, test_y。如果你需要,直接条用main方法即可!

机器学习笔记-MNIST数据库

如果你已经知道MNIST是什么,以及什么softmax(多项式逻辑)回归,那么你可能更喜欢这个更快节奏的教程。在开始任一教程之前,请务必安装TensorFlow。当学习如何编程时,有一个传统,你所做的...
  • jiaodui
  • jiaodui
  • 2017年04月29日 18:20
  • 363

【机器学习】MNIST数据集上的python读取和使用操作

MNIST手写字符数据集由LeCun大神提出。该数据集在机器学习中就相当于程序中的“Hello World”的存在。由于这个数据集可以很好测试我们的一些分类算法,本博客将对该数据集的读取操作等进行解释...
  • shwan_ma
  • shwan_ma
  • 2017年08月26日 16:45
  • 409

机器学习-实战-入门-MNIST手写数字识别

作者:橘子派 声明:版权所有,转载请注明出处,谢谢。 源码地址:https://github.com/sileixinhua/Python_Machine_Learning_Sklearn...
  • sileixinhua
  • sileixinhua
  • 2017年04月22日 21:49
  • 2256

(超详细)读取mnist数据集并保存成图片

mnist数据集介绍、读取、保存成图片 1、mnist数据集介绍: MNIST数据集是一个手写体数据集,简单说就是一堆这样东西  MNIST的官网地址是 MNIST; 通过阅读官网我们可以知...
  • YF_Li123
  • YF_Li123
  • 2017年08月05日 11:38
  • 5363

详解 MNIST 数据集

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下.MNIST 数据集可在 http://yan...
  • simple_the_best
  • simple_the_best
  • 2017年07月17日 20:41
  • 24227

识别MNIST数据集之(一):读取数据

MINIST读数据
  • superCally
  • superCally
  • 2017年01月08日 20:43
  • 9376

深度学习数据集——MNIST

简单介绍MNIST数据集
  • zchang81
  • zchang81
  • 2017年03月13日 15:24
  • 814

MNIST数据集解析

官网一探 MNIST数据集是一个手写体数据集,简单说就是一堆这样东西 MNIST的官网地址是 MNIST; 通过阅读官网我们可以知道,这个数据集由四部分组成,分别是 ;也就是一个训练图片集,一个训练...
  • sysushui
  • sysushui
  • 2016年11月21日 10:49
  • 17653

机器学习数据集-MNIST

介绍在学习机器学习的时候,首当其冲的就是准备一份通用的数据集,方便与其他的算法进行比较。在这里,我写了一个用于加载MNIST数据集的方法,并将其进行封装,主要用于将MNIST数据集转换成numpy.a...
  • Dream_angel_Z
  • Dream_angel_Z
  • 2016年02月25日 18:46
  • 13092

tensorflow入门之mnist手写数据集识别

最近开始研究机器学习,整个模型都自己写的话不太现实,所以还是得用框架。几经查找,选择了Google的Tensorflow框架,这个起步也还比较好用,网上参考资料也很多。 参考的教程官网:[tensor...
  • cl6348
  • cl6348
  • 2017年11月01日 18:05
  • 247
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:机器学习数据集-MNIST
举报原因:
原因补充:

(最多只允许输入30个字)