在TensorFlow的官方入门课程中,多次用到mnist数据集。
mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx3-ubyte的二进制文件。
如果我们想要知道大名鼎鼎的mnist手写体数字都长什么样子,就需要从mnist数据集中导出手写体数字图片。了解这些手写体的总体形状,也有助于加深我们对TensorFlow入门课程的理解。
下面先给出通过TensorFlow api接口导出mnist手写体数字图片的python代码,再对代码进行分析。代码在win7下测试通过,linux环境也可以参考本处代码。
(非常良心的注释和打印有木有)
#!/usr/bin/python3.5
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image
# 声明图片宽高
rows = 28
cols = 28
# 要提取的图片数量
images_to_extract = 8000
# 当前路径下的保存目录
save_dir = "./mnist_digits_images"
# 读入mnist数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
# 创建会话
sess = tf.Session()
# 获取图片总数
shape = sess.run(tf.shape(mnist.train.images))
images_count = shape[0]
pixels_per_image = shape[1]
#