mnist数据集学习

在学习机器学习的过程中,数据集是进行机器学习的必备材料,我在学习分类算法的时候接触到了MNIST数据集,这是Yann LeCun教授贡献的,具体下载可以通过这个链接http://yann.lecun.com/exdb/mnist/

这个数据集包含4个部分,

文件内容
train-images-idx3-ubyte.gz训练图片60000张
train-labels-idx3-ubyte.gz训练图片的标记60000个
t10k-images-idx3-ubyte.gz测试集图片10000张
t10k-labels-idx3-ubyte.gz测试集图片的标记10000个

图片是以字节形式存储的,我们需要将它们读取到NumPy array中,以便训练和测试算法。代码如下:

import os
import struct
import numpy as np

def load_mnist(path,kind='train'):
    """load MNIST data from 'path'"""
    labels_path = os.path.join(path,'%s-labels-idxs-ubyte'%kind)
    images_path = os.path.join(path,'%s-images-idxs-ubyte'%kind)

    with open(labels_path,'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath,dtype=np.uint8)

    with open(images_path,'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII',imgpack.read(16))
        images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels),784)
    
    return images, labels



      

load_mnist函数返回两个数组,第一个是一个n*m维的NumPy array(images),这里的n是样本数,m是特征数,训练数据集包含60,000个样本,测试数据集包含10,000个样本,在MNIST数据集中每张图片由28*28的像素点组成,每个像素点由一个灰度值表示,在这里我们将28*28的像素点展成为一个一维的行向量,每一行代表一张图片;返回的第二个数组中包含了相应的目标变量,也就是类标签。

为了了解MNIST中图片的样子,用matplotlib进行可视化处理。

import matplotlib.pyplot as plt

fig, ax = plt.subplot(nrows=2,ncols=5,sharex=True,sjarey=True,)
ax = ax.flatten()

for i in range(10):
    img = X_train[y_train==i][0].reshape(28,28)

    ax[i].imshow(img,cmap='Grays',interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()


 

 

 

 

 

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值