tfrecord生成

验证码预处理

在我们训练验证码识别的网络的时候我们需要进行验证码数据的生成和生成训练所需要的tfrecord文件,所以博主这次讲解的内容是进行验证码网络训练过程中的图像和数据生成的预操作,让你生成你训练所需要的数据。



导入数据包

这里我们导入一个特殊的库,这个库是用来生成我们验证码图像的,我们基础库里面是没有的,我们需要提前进行下载

我们打开cmd命令提示符,我们在里面输入:pip install captcha 下载这个库


# 验证码生成器
from captcha.image import ImageCaptcha
import numpy as np
from PIL import Image
import random
import sys
import os


一、 生成四位验证码

我们这里使用的是十个数字**(0-9)**来生成我们需要的数据,我们使用一个函数来生成,遍历四个数,我们随机从我们的列表里面挑选四个数字,然后返回到我们的验证码列表中。


number = list()
for i in range(10):
    number.append(str(i))
# alphact = [a-z]
# alpharet = [A-Z]

# 生成四位验证码
def random_captcha_text(char_set=number, captcha_size=4):
    # 验证码列表
    captcha_text = []
    for i in range(captcha_size):
        # 随机选择
        c = random.choice(char_set)
        # 加入验证码列表
        captcha_text.append(c)
    return captcha_text



二、 生成对应的验证码

在这里我们生成对应的验证码,使用我们的库函数ImageCaptcha来生成,写入我们用数字名字定义的文件中并进行保存。


def gen_captcha_text_and_image():
    image = ImageCaptcha()  # 生成验证码对象
    # 获得随机生成的验证码
    captcha_text = random_captcha_text()
    # 把验证码列表转为字符串
    captcha_text = ''.join(captcha_text)
    # 生成验证码
    captcha = image.generate(captcha_text)
    if not os.path.exists('captcha/images/'):
        os.makedirs('captcha/images/')
    file_name = 'captcha/images/' + captcha_text + '.jpg'
    image.write(captcha_text, file_name)  # 写进文件



三、 生成图像

在这里我们生成我们10000张图像,调用我们原来生成的函数gen_captcha_text_and_image(),写入我们的文件在里面,在这里值得注意的是我们生成的图像不是10000张,因为你随机生成可能会发生重复的图像,所以他生成的肯定没有我们定义的数量那么多。


num = 10000
if __name__ == '__main__':
    for i in range(num):
        gen_captcha_text_and_image()
        sys.stdout.write('\r>> Creating image %d/%d' % (i+1, num))
        sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()
    print('生成完成')
四、完整实现
# 验证码生成器
from captcha.image import ImageCaptcha   # pip install captcha
import numpy as np
from PIL import Image
import random
import sys
import os

#
number = list()
for i in range(10):
    number.append(str(i))
# alphact = [a-z]
# alpharet = [A-Z]

# 生成四位验证码
def random_captcha_text(char_set=number, captcha_size=4):
    # 验证码列表
    captcha_text = []
    for i in range(captcha_size):
        # 随机选择
        c = random.choice(char_set)
        # 加入验证码列表
        captcha_text.append(c)
    return captcha_text

# 生成字符对应的验证码
def gen_captcha_text_and_image():
    image = ImageCaptcha()  # 生成验证码对象
    # 获得随机生成的验证码
    captcha_text = random_captcha_text()
    # 把验证码列表转为字符串
    captcha_text = ''.join(captcha_text)
    # 生成验证码
    captcha = image.generate(captcha_text)
    if not os.path.exists('captcha/images/'):
        os.makedirs('captcha/images/')
    file_name = 'captcha/images/' + captcha_text + '.jpg'
    image.write(captcha_text, file_name)  # 写进文件

# 数量少于10000,因为重名,生成一万张,但是会出现重名 所以不是10000张。
num = 10000
if __name__ == '__main__':
    for i in range(num):
        gen_captcha_text_and_image()
        sys.stdout.write('\r>> Creating image %d/%d' % (i+1, num))
        sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()
    print('生成完成')




生成tfrecord文件

我们在生成数据的时候我们需要将数据转化成tfrecord文件,那就下来我们讲解的就是这个操作。



导入库
import tensorflow as tf
import os
import random
import math
import sys
from PIL import Image
import numpy as np



一、 定义数据参数

定义数据的存储位置和其他基本的参数

# 验证集数量
_NUM_TEST = 500

# 随机种子
_RANDOM_SEED = 0

# 数据集路径
DATASET_DIR = "./captcha/images/"

# tfrecord文件存放路径
TFRECORD_DIR = "./captcha/"



二、检查tfrecord文件是不是存在

查看我们的tfrecord文件是不是存在,不存在我么就需要进行生成。

def _dataset_exists(dataset_dir):
    for split_name in ['train', 'test']:
        output_filename = os.path.join(dataset_dir, split_name + '.tfrecords')
        if not tf.gfile.Exists(output_filename):
            return False
    return True


三、 获取图像文件

通过传入我们验证码图像的路径,来生成我们的文件路径的列表方便后面的读取。

def _get_filenames_and_classes(dataset_dir):
    photo_filenames = []
    for filename in os.listdir(dataset_dir):
        # 获取文件路径
        path = os.path.join(dataset_dir, filename)
        photo_filenames.append(path)
    return photo_filenames



四、 数据转换

用来准换image数据和image对应的验证码的数据

def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))



五、定义数据

定义image数据和对应的值,使用上面提到的数据转换函数,返回结果

def image_to_tfexample(image_data, label0, label1, label2, label3):
    # Abstract base class for protocol messages.
    return tf.train.Example(features=tf.train.Features(feature={
        'image': bytes_feature(image_data),
        'label0': int64_feature(label0),
        'label1': int64_feature(label1),
        'label2': int64_feature(label2),
        'label3': int64_feature(label3),
    }))


六、转换成tfrecord格式

首先我们定义tfrecord文件的路径和名字,在后面我们遍历我们生成的文件名字的列表,打开图像文件,转换他的类型,在这个后面我们需要对数据进行resize,因为验证码生成的需要的大小是这样我们需要进行更改size,灰度化处理,转换为bytes,获取图像路径下的图像对应的数字(就是文件的文件名),在生成protocol数据类型。

def _convert_dataset(split_name, filenames, dataset_dir):
    assert split_name in ['train', 'test']
    
    with tf.Session() as sess:
        # 定义tfrecord文件的路径+名字
        output_filename = os.path.join(TFRECORD_DIR, split_name + '.tfrecords')
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            for i, filename in enumerate(filenames):
                try:
                    sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                    sys.stdout.flush()

                    # 读取图片
                    image_data = Image.open(filename)
                    # 根据模型的结构resize
                    image_data = image_data.resize((224, 224))
                    # 灰度化
                    image_data = np.array(image_data.convert('L'))
                    # 将图片转化为bytes
                    image_data = image_data.tobytes()

                    # 获取label
                    labels = filename.split('/')[-1][0:4]
                    num_labels = []
                    for j in range(4):
                        num_labels.append(int(labels[j]))

                    # 生成protocol数据类型
                    example = image_to_tfexample(image_data, num_labels[0], num_labels[1], num_labels[2], num_labels[3])
                    tfrecord_writer.write(example.SerializeToString())

                except IOError as e:
                    print('Could not read:', filename)
                    print('Error:', e)
                    print('Skip it\n')
    sys.stdout.write('\n')
    sys.stdout.flush()




七、 判断是否存在并划分数据

先判断是不是文件已经有了,如果没有获取我们所有的图像,切分数据集和测试集合,我们惊醒数据的转换在生成tfrecord文件。

if _dataset_exists(TFRECORD_DIR):
    print('tfcecord文件已存在')
else:
    # 获得所有图片
    photo_filenames = _get_filenames_and_classes(DATASET_DIR)

    # 把数据切分为训练集和测试集,并打乱
    random.seed(_RANDOM_SEED)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[_NUM_TEST:]
    testing_filenames = photo_filenames[:_NUM_TEST]

    # 数据转换
    _convert_dataset('train', training_filenames, DATASET_DIR)
    _convert_dataset('test', testing_filenames, DATASET_DIR)
    print('生成tfcecord文件')


八、 完整实现
import tensorflow as tf
import os
import random
import math
import sys
from PIL import Image
import numpy as np

# In[2]:

# 验证集数量
_NUM_TEST = 500

# 随机种子
_RANDOM_SEED = 0

# 数据集路径
DATASET_DIR = "./captcha/images/"

# tfrecord文件存放路径
TFRECORD_DIR = "./captcha/"


# 判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
    for split_name in ['train', 'test']:
        output_filename = os.path.join(dataset_dir, split_name + '.tfrecords')
        if not tf.gfile.Exists(output_filename):
            return False
    return True


# 获取所有验证码图片
def _get_filenames_and_classes(dataset_dir):
    photo_filenames = []
    for filename in os.listdir(dataset_dir):
        # 获取文件路径
        path = os.path.join(dataset_dir, filename)
        photo_filenames.append(path)
    return photo_filenames


def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def image_to_tfexample(image_data, label0, label1, label2, label3):
    # Abstract base class for protocol messages.
    return tf.train.Example(features=tf.train.Features(feature={
        'image': bytes_feature(image_data),
        'label0': int64_feature(label0),
        'label1': int64_feature(label1),
        'label2': int64_feature(label2),
        'label3': int64_feature(label3),
    }))


# 把数据转为TFRecord格式
def _convert_dataset(split_name, filenames, dataset_dir):
    assert split_name in ['train', 'test']

    with tf.Session() as sess:
        # 定义tfrecord文件的路径+名字
        output_filename = os.path.join(TFRECORD_DIR, split_name + '.tfrecords')
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            for i, filename in enumerate(filenames):
                try:
                    sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                    sys.stdout.flush()

                    # 读取图片
                    image_data = Image.open(filename)
                    # 根据模型的结构resize
                    image_data = image_data.resize((224, 224))
                    # 灰度化
                    image_data = np.array(image_data.convert('L'))
                    # 将图片转化为bytes
                    image_data = image_data.tobytes()

                    # 获取label
                    labels = filename.split('/')[-1][0:4]
                    num_labels = []
                    for j in range(4):
                        num_labels.append(int(labels[j]))

                    # 生成protocol数据类型
                    example = image_to_tfexample(image_data, num_labels[0], num_labels[1], num_labels[2], num_labels[3])
                    tfrecord_writer.write(example.SerializeToString())

                except IOError as e:
                    print('Could not read:', filename)
                    print('Error:', e)
                    print('Skip it\n')
    sys.stdout.write('\n')
    sys.stdout.flush()


# 判断tfrecord文件是否存在
if _dataset_exists(TFRECORD_DIR):
    print('tfcecord文件已存在')
else:
    # 获得所有图片
    photo_filenames = _get_filenames_and_classes(DATASET_DIR)

    # 把数据切分为训练集和测试集,并打乱
    random.seed(_RANDOM_SEED)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[_NUM_TEST:]
    testing_filenames = photo_filenames[:_NUM_TEST]

    # 数据转换
    _convert_dataset('train', training_filenames, DATASET_DIR)
    _convert_dataset('test', testing_filenames, DATASET_DIR)
    print('生成tfcecord文件')



说明

本博客是自己学习说明,具体教程在:https://www.bilibili.com/video/BV1kW411W7pZ?p=31

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值