将验证码图像转换为TFRecord文件

import tensorflow as tf
import os

"""
注意: 这个API处理的图片格式是:图片的名字就是标签的内容, 例如一张图片叫做'2a3w.jpg', 2a3w即为标签值.
将图片, 标签转换成TFrecord格式, 其中标签已经由[2a3w] >> [12, 23, 45, 46]
"""

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("pic_dir", "../pic_source/train6000/", "源图片的路径")
tf.flags.DEFINE_string("letter", "abcdefghijklmnopqrstuvwxyz1234567890", "验证码字符的种类")
tf.flags.DEFINE_string("tfrecords_dir", "../pic_source/train6000_tfrecord/train6000.tfrecords", "验证码tfrecords文件")
tf.flags.DEFINE_integer("image_num",367,"总共有多少张图片")

def dealWithLabel(label_str):
    """

    :param label_str: ['2a2s', '2w3e', '4e5r', '3e5r', '4r5g'....]
    :return:
    """
    # 构建字符索引 {0:'a', 1:'b'......}
    num2letter = dict(enumerate(list(FLAGS.letter)))
    # 键值对翻转 {'a':0, 'b':1......}
    letter2num = dict(zip(num2letter.values(),num2letter.keys()))
    print(letter2num)

    # 构建标签的列表
    array = []
    for lablex in label_str:

        letter_list = []

        for letter in list(str(lablex)): # '2a2s'>>[2,a,2,s]
            letter2_num = letter2num[str(letter)] #a>>2
            letter_list.append(letter2_num)

        array.append(letter_list)
    print(array)

    # 将array转换成tensor类型
    lable = tf.constant(array)

    return lable

def get_image(file_name):
    """
    获取验证码的图片数据,以及标签
    注意: 这个API处理的图片格式是:图片的名字就是标签的内容,例如一张图片叫做 '2a3w.jpg',2a3w即为标签值.
    :return: image_batch,lable_batch
    """

    # 遍历获取标签名字 ['2a2s', '2w3e', '4e5r', '3e5r', '4r5g'....]
    # file_name = [str(lable).split(".")[0] for lable in os.listdir(path=FLAGS.pic_dir)]
    # 构造路径+文件 ['../pic_source/train6000/2a2s.jpg', '', '', '', '', ''.....]
    file_list = [os.path.join(FLAGS.pic_dir, labl + ".jpg") for labl in file_name]
    # 构造文件队列
    image_queue = tf.train.string_input_producer(file_list, shuffle=False)
    # 构造阅读器
    image_reader = tf.WholeFileReader()
    # 读取图片内容
    key, value = image_reader.read(image_queue)
    # 解码图片数据
    image = tf.image.decode_jpeg(value)
    image.set_shape([60, 180, 3]) # 必须按照height高 * weight宽 * 通道数 的顺序来写
    # 批处理数据 [367, 60, 180, 3]
    image_batch = tf.train.batch([image], batch_size=len(file_name),num_threads=1, capacity=len(file_name))
    # lable_batch = tf.train.batch(lable_queue, batch_size=len(file_name), num_threads=1, capacity=len(file_name))

    print(image_batch)
    # print(lable_batch)
    return image_batch

def write2tfrecords(image_batch, lable_batch):
    """
     将图片内容和标签写入到tfrecords文件当中
    :param image_batch:特征值
    :param lable_batch:标签值
    :return:
    """
    # 类型转换
    lable_batch = tf.cast(lable_batch,tf.uint8)
    print(lable_batch)

    # 建立TFRecords存储器
    writer = tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir)

    # 循环将每一个图片上的数据构造example协议块,序列化后写入
    for i in range(FLAGS.image_num):
        # 取出第i个图片数据,转换相应类型,图片的特征值要转换成字符串形式
        image_string = image_batch[i].eval().tostring()

        # 标签值,转换成整型
        label_string = lable_batch[i].eval().tostring()

        # 构造协议块
        example = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])),
            "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_string]))
        }))

        writer.write(example.SerializeToString())
    # 关闭文件
    writer.close()

    return None

def generateTFrecord(file_name):
    """
    注意: 这个API处理的图片格式是:图片的名字就是标签的内容,例如一张图片叫做 '2a3w.jpg',2a3w即为标签值.
    将图片,标签转换成TFrecord格式,其中标签已经由 [2a3w]>>[12,23,45,46]
    :param file_name: []
    :return:
    """
    # 获取验证码中的图片
    image_batch = get_image(file_name)
    # 获取标签数据
    label_batch = dealWithLabel(label_str=file_name)
    print(image_batch, label_batch)

    with tf.Session() as sess:
        coord = tf.train.Coordinator() #构建队列
        threads = tf.train.start_queue_runners(sess=sess, coord=coord) #构建线程

        # 将图片数据和内容写入到tfrecords文件当中
        write2tfrecords(image_batch, label_batch)
        coord.request_stop()
        coord.join(threads)
        print('生成TFRecord文件成功.............................................................')
    return None

if __name__ == '__main__':
    file_name = [str(lable).split(".")[0] for lable in os.listdir(path=FLAGS.pic_dir)]
    generateTFrecord(file_name=file_name)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值