Tensorflow MNIST原始图片TFRecord方式识别---1. 原始图片生成TFRecord文件

Tensorflow MNIST原始图片TFRecord方式识别---1. 原始图片生成TFRecord文件


想做一个完整的tensorflow手写数字识别。计划步骤如下:
1)原始图片,数据处理,生成TFRecord文件
2)设计CNN MNIST手写数字识别的模型
3)从TFRecord文件中,提取图片数据,进行训练
4)取一张手写数字图片进行测试

1. 原始图片生成TFRecord文件

1.1 原始图片准备

在各大学习资源上,基本都是标准的Tensorflow MNIST数据集,如下:
在这里插入图片描述
我需要的是原始png图片。参考了https://blog.csdn.net/ITBigGod/article/details/83788865。此博客对标准的MNIST 数据集,进行了图片的提取。

1.2 手动打乱数据

没有这一步的时候,用顺序写入生成tfrecord文件,做过完整的训练和测试,识别准确率惨不忍睹。和标准MNIST数据集mnist.train.next_batch(…)相比较:
1)标准mnist数据集每次提取出来的数据,包含多了类别的;
2)原始mnist数据图片,经tfrecord文件,通过shuffle_batch随机提取batch_size大小的数据,基本每次都是同类别。
猜测是这方面的差别,导致训练的模型很差,最终识别准确率低。
基于此猜想,暂时想到笨的方法:手动打乱数据集【注意:提前备份原始图片数据】。具体方式:将0-9所有的图片文件,以“1-1000随机整数 + 当前时间分秒 + 识别数字”重新命名。比如:10000201573510.png,最后一位是0,代表这个类别是手写数字0,方便以此根据,在写tfrecord文件时,写到键值label中。

import os
import random
import datetime

origin_path = '../datasets/MNIST_PNG_Data/train/'
target_path = '../datasets/mnist_shuffle/'

for num in range(10):
    num_path = os.path.join(origin_path, str(num))
    for file in os.listdir(num_path):
        oldname = os.path.join(num_path, file)
        newfile = str(random.randint(1,10000)) + \
            str(datetime.datetime.now())[-9:].replace('.', '') + \
            str(num) + \
            '.png'
        newname = os.path.join(target_path, newfile)
        print("oldname: ", oldname)
        print("newname: ", newname)
        os.rename(oldname, newname)

在这里插入图片描述

1.3 生成TFRecord文件

有关TFRecord相关介绍,可以参考我的另一篇博客《Tensorflow的TFRecord文件的读写学习》。生成
完整代码如下:

# -*- encoding = utf-8 -*-
import os
import tensorflow as tf
import cv2

ROOT_PATH = "../src/LeNet_Mnist_Origin/"
TFRECORD_PATH = os.path.join(ROOT_PATH, "tfrecord_data")

def gen_hand_shuffle_tfrecord(origin_data_path, tfrecord_file):
    """原始手写数字图片转化成tfreord文件"""
    # zero_path = os.path.join(origin_data_path, "0/")
    # one_path = os.path.join(origin_data_path, "1/")
    if not os.path.exists(TFRECORD_PATH):
        os.makedirs(TFRECORD_PATH)
    record_writer = tf.python_io.TFRecordWriter(tfrecord_file)
    num_examples = 0
    for file in os.listdir(origin_data_path):
        file_name = os.path.join(origin_data_path, file)
        label = int(str(file)[-5])
        print(label, ' ', file)
        image = cv2.imread(file_name)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        image = image.tostring()
        example = tf.train.Example()
        feature = example.features.feature
        feature['image_raw'].bytes_list.value.append(image)
        feature['label'].int64_list.value.append(label)
        record_writer.write(example.SerializeToString())
        num_examples += 1

    record_writer.close()
    return num_examples

if __name__ == '__main__':
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    hand_shuffle_mnist_data_path = '../datasets/mnist_shuffle/'
    hand_shuffle_mnist_tfrecord_path = '../src/LeNet_Mnist_Origin/tfrecord_data/hand_shuffle_mnist.tfrecord'
    if not os.path.isfile(hand_shuffle_mnist_tfrecord_path):
        gen_hand_shuffle_tfrecord(hand_shuffle_mnist_data_path, hand_shuffle_mnist_tfrecord_path)

在这里插入图片描述

1.4 验证TFRecord文件数据

在这里插入图片描述

下一篇,Tensorflow MNIST原始图片TFRecord方式识别—2. 设计CNN MNIST手写数字识别的模型

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值