Tensorflow MNIST原始图片TFRecord方式识别---1. 原始图片生成TFRecord文件
想做一个完整的tensorflow手写数字识别。计划步骤如下:
1)原始图片,数据处理,生成TFRecord文件
2)设计CNN MNIST手写数字识别的模型
3)从TFRecord文件中,提取图片数据,进行训练
4)取一张手写数字图片进行测试
1. 原始图片生成TFRecord文件
1.1 原始图片准备
在各大学习资源上,基本都是标准的Tensorflow MNIST数据集,如下:
我需要的是原始png图片。参考了https://blog.csdn.net/ITBigGod/article/details/83788865。此博客对标准的MNIST 数据集,进行了图片的提取。
1.2 手动打乱数据
没有这一步的时候,用顺序写入生成tfrecord文件,做过完整的训练和测试,识别准确率惨不忍睹。和标准MNIST数据集mnist.train.next_batch(…)相比较:
1)标准mnist数据集每次提取出来的数据,包含多了类别的;
2)原始mnist数据图片,经tfrecord文件,通过shuffle_batch随机提取batch_size大小的数据,基本每次都是同类别。
猜测是这方面的差别,导致训练的模型很差,最终识别准确率低。
基于此猜想,暂时想到笨的方法:手动打乱数据集【注意:提前备份原始图片数据】。具体方式:将0-9所有的图片文件,以“1-1000随机整数 + 当前时间分秒 + 识别数字”重新命名。比如:10000201573510.png,最后一位是0,代表这个类别是手写数字0,方便以此根据,在写tfrecord文件时,写到键值label中。
import os
import random
import datetime
origin_path = '../datasets/MNIST_PNG_Data/train/'
target_path = '../datasets/mnist_shuffle/'
for num in range(10):
num_path = os.path.join(origin_path, str(num))
for file in os.listdir(num_path):
oldname = os.path.join(num_path, file)
newfile = str(random.randint(1,10000)) + \
str(datetime.datetime.now())[-9:].replace('.', '') + \
str(num) + \
'.png'
newname = os.path.join(target_path, newfile)
print("oldname: ", oldname)
print("newname: ", newname)
os.rename(oldname, newname)
1.3 生成TFRecord文件
有关TFRecord相关介绍,可以参考我的另一篇博客《Tensorflow的TFRecord文件的读写学习》。生成
完整代码如下:
# -*- encoding = utf-8 -*-
import os
import tensorflow as tf
import cv2
ROOT_PATH = "../src/LeNet_Mnist_Origin/"
TFRECORD_PATH = os.path.join(ROOT_PATH, "tfrecord_data")
def gen_hand_shuffle_tfrecord(origin_data_path, tfrecord_file):
"""原始手写数字图片转化成tfreord文件"""
# zero_path = os.path.join(origin_data_path, "0/")
# one_path = os.path.join(origin_data_path, "1/")
if not os.path.exists(TFRECORD_PATH):
os.makedirs(TFRECORD_PATH)
record_writer = tf.python_io.TFRecordWriter(tfrecord_file)
num_examples = 0
for file in os.listdir(origin_data_path):
file_name = os.path.join(origin_data_path, file)
label = int(str(file)[-5])
print(label, ' ', file)
image = cv2.imread(file_name)
image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
image = image.tostring()
example = tf.train.Example()
feature = example.features.feature
feature['image_raw'].bytes_list.value.append(image)
feature['label'].int64_list.value.append(label)
record_writer.write(example.SerializeToString())
num_examples += 1
record_writer.close()
return num_examples
if __name__ == '__main__':
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
hand_shuffle_mnist_data_path = '../datasets/mnist_shuffle/'
hand_shuffle_mnist_tfrecord_path = '../src/LeNet_Mnist_Origin/tfrecord_data/hand_shuffle_mnist.tfrecord'
if not os.path.isfile(hand_shuffle_mnist_tfrecord_path):
gen_hand_shuffle_tfrecord(hand_shuffle_mnist_data_path, hand_shuffle_mnist_tfrecord_path)
1.4 验证TFRecord文件数据
下一篇,Tensorflow MNIST原始图片TFRecord方式识别—2. 设计CNN MNIST手写数字识别的模型