Tensorflow2.x读取手写体数字识别MNIST数据集

最近在积攒粉丝500,大家帮帮忙,动动小手指关注、点赞、收藏…🙏🙏🙏🙏🙏🙏


一、了解数据集的结构

MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片,也包含了每一张图像的标签(告诉我们图片是数字几):
在这里插入图片描述

MNIST数据集网上下载得到的分为训练数据集和测试数据集两部分(数据集的数据都是由图片数据集和对应的标签数据集组成)。

其下载地址: http://yann.lecun.com/exdb/mnist/
下载得到包含如下几个文件:
在这里插入图片描述
在这里插入图片描述

训练集:用于训练的数据。

  • train-images-idx3-ubyte.gz: 包含6万张28*28图片。
  • train-labels-idx1-ubyte.gz:包含6万张图片的标签,即每一张是什么数字。

测试集:用于测试模型的泛化能力。

  • t10k-images-idx3-ubyte.gz: 包含1万张28*28图片。
  • t10k-labels-idx1-ubyte.gz:包含1万张图片的标签,即每一张是什么数字。

二、读取代码

解压标签文件和数据图像文件,然后再调用tensorflow2.x的数据集API进行预处理、洗牌、分批次等。

2.1 定义解压图像文件函数

def load_images(filename):
    """load images
    filename: the name of the file containing data
    return -- a matrix containing images as row vectors
    """
    g_file = gzip.GzipFile(filename)
    data = g_file.read()
    magic, num, rows, columns = struct.unpack('>iiii', data[:16])
    dimension = rows*columns
    X = np.zeros((num,rows,columns), dtype='uint8')
    offset = 16
    for i in range(num):
        a = np.frombuffer(data, dtype=np.uint8, count=dimension, offset=offset)
        X[i] = a.reshape((rows, columns))
        offset += dimension
    return X

2.2 定义解压标签文件函数

def load_labels(filename):
    """load labels
    filename: the name of the file containing data
    return -- a row vector containing labels
    """
    g_file = gzip.GzipFile(filename)
    data = g_file.read()
    magic, num = struct.unpack('>ii', data[:8])
    d = np.frombuffer(data,dtype=np.uint8, count=num, offset=8)
    return d

2.3 调用以上两个函数把4个压缩文件解压

def load_data(foldername):
    """加载MINST数据集
    foldername: the name of the folder containing datasets
    return -- train_X训练数据集, train_y训练数据集对应的标签,
        test_X测试数据集, test_y测试数据集对应的标签
    """
    # filenames of datasets
    train_X_name = "train-images-idx3-ubyte.gz"
    train_y_name = "train-labels-idx1-ubyte.gz"
    test_X_name = "t10k-images-idx3-ubyte.gz"
    test_y_name = "t10k-labels-idx1-ubyte.gz"
    train_X = load_images(os.path.join(foldername, train_X_name))
    train_y = load_labels(os.path.join(foldername,train_y_name))
    test_X = load_images(os.path.join(foldername, test_X_name))
    test_y = load_labels(os.path.join(foldername, test_y_name))
    return train_X, train_y, test_X, test_y

2.4 Tensorflow2.x的接口读取数据集

调用tensorflow2.0数据集处理API,进行图片预处理、图像洗牌、分批次等。

def process_image(image, label):
    """ 图片预处理 """
    # m = image.shape[0] * image.shape[1]
    # image = tf.reshape(image, (m,))      # 全连接网络输入(768,); 2D卷积网络不需要这个转换
    label = tf.one_hot(label, depth=10)
    return image, label

def get_dataset(X, Y, batch_size=64):
    ds = tf.data.Dataset.from_tensor_slices((X, Y))
    ds = ds.map(process_image)
    ds = ds.shuffle(buffer_size=1024)
    ds = ds.batch(batch_size)
    return ds

2.5 测试

测试一下读取一个batch的数据:

if __name__ == "__main__":
    # 读取数据集
    train_X, train_y, test_X, test_y = load_data("./data/MNIST")
    train_dataset = get_dataset(train_X, train_y, batch_size=64)
    test_dataset = get_dataset(test_X, test_y, batch_size=64)
    # 打印查看
    for nbatch, (x, labels) in enumerate(train_dataset):
        print("train x:", x.shape)
        print("train labels:", labels.shape)
        break
    for nbatch, (x, labels) in enumerate(test_dataset):
        print("test x:", x.shape)
        print("test labels:", labels.shape)
        break

在这里插入图片描述


最近在积攒粉丝500,大家帮帮忙,动动小手指关注、点赞、收藏…🙏🙏🙏🙏🙏🙏

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI学长

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值