利用numpy读取mnist数据集

读取并分析如下四个文件

‘train-images-idx3-ubyte’
‘train-labels-idx1-ubyte’
‘t10k-images-idx3-ubyte’
‘t10k-labels-idx1-ubyte’
#_*_coding:utf-8_*_
import numpy as np
import os
class Mnist(object):

    def __init__(self):

        self.dataname = "Mnist"
        self.dims = 28*28
        self.shape = [28 , 28 , 1]
        self.image_size = 28
        self.data, self.data_y = self.load_mnist()

    def load_mnist(self):

        data_dir = os.path.join("./data", "mnist")
        fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte'))
		# 利用np.fromfile语句将这个ubyte文件读取进来
		# 需要注意的是用np.uint8的格式
		# 还有读取进来的是一个一维向量
		# <type 'tuple'>: (47040016,),这就是loaded变量的读完之后的数据类型
        loaded = np.fromfile(file=fd , dtype=np.uint8)
		trX = loaded[16:].reshape((60000, 28 , 28 ,  1)).astype(np.float)
		#'train-images-idx3-ubyte'这个文件前十六位保存的是一些说明具体打印结果如下:
        point = loaded[:16]
        print(point)
        # [  0   0   8   3   0   0 234  96   0   0   0  28   0   0   0  28]
		# 序号从1开始,上述数字有下面这几个公式的含义
		# MagicNum = ((a(1)*256+a(2))*256+a(3))*256+a(4);
		# ImageNum = ((a(5)*256+a(6))*256+a(7))*256+a(8);    等于60000
		# ImageRow = ((a(9)*256+a(10))*256+a(11))*256+a(12); 等于28
		# ImageCol = ((a(13)*256+a(14))*256+a(15))*256+a(16);等于28

        fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte'))
        loaded = np.fromfile(file=fd, dtype=np.uint8)
        trY = loaded[8:].reshape((60000)).astype(np.float)
		
		
		point = loaded[:8]
		print(point)
		# [  0   0   8   1   0   0 234  96]
		# 这些数字的作用和上述类似
		# 这些数字的功能之一就是可以判断你下载的数据集对不对,全不全

        fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte'))
        loaded = np.fromfile(file=fd, dtype=np.uint8)
        teX = loaded[16:].reshape((10000, 28 , 28 , 1)).astype(np.float)

        fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte'))
        loaded = np.fromfile(file=fd, dtype=np.uint8)
        teY = loaded[8:].reshape((10000)).astype(np.float)

        trY = np.asarray(trY)
        teY = np.asarray(teY)

        X = np.concatenate((trX, teX), axis=0)
        y = np.concatenate((trY, teY), axis=0)
		
		#目的是为了打乱数据集
		#这里随意固定一个seed,只要seed的值一样,那么打乱矩阵的规律就是一眼的
        seed = 666
        np.random.seed(seed)
        np.random.shuffle(X)
        np.random.seed(seed)
        np.random.shuffle(y)

        #convert label to one-hot
		#手动将数据转换成one-hot编码形式
        y_vec = np.zeros((len(y), 10), dtype=np.float)
        for i, label in enumerate(y):
            y_vec[i, int(y[i])] = 1.0

        return X / 255., y_vec
if __name__ == "__main__":
    #
    mn_object = Mnist()
    x = mn_object.data
    y = mn_object.data_y
    
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yuanCruise

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值