import tensorflow as tf
import os
"""
注意: 这个API处理的图片格式是:图片的名字就是标签的内容, 例如一张图片叫做'2a3w.jpg', 2a3w即为标签值.
将图片, 标签转换成TFrecord格式, 其中标签已经由[2a3w] >> [12, 23, 45, 46]
"""
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("pic_dir", "../pic_source/train6000/", "源图片的路径")
tf.flags.DEFINE_string("letter", "abcdefghijklmnopqrstuvwxyz1234567890", "验证码字符的种类")
tf.flags.DEFINE_string("tfrecords_dir", "../pic_source/train6000_tfrecord/train6000.tfrecords", "验证码tfrecords文件")
tf.flags.DEFINE_integer("image_num",367,"总共有多少张图片")
def dealWithLabel(label_str):
"""
:param label_str: ['2a2s', '2w3e', '4e5r', '3e5r', '4r5g'....]
:return:
"""
# 构建字符索引 {0:'a', 1:'b'......}
num2letter = dict(enumerate(list(FLAGS.letter)))
# 键值对翻转 {'a':0, 'b':1......}
letter2num = dict(zip(num2letter.values(),num2letter.keys()))
print(letter2num)
# 构建标签的列表
array = []
for lablex in label_str:
letter_list = []
for letter in list(str(lablex)): # '2a2s'>>[2,a,2,s]
letter2_num = letter2num[str(letter)] #a>>2
letter_list.append(letter2_num)
array.append(letter_list)
print(array)
# 将array转换成tensor类型
lable = tf.constant(array)
return lable
def get_image(file_name):
"""
获取验证码的图片数据,以及标签
注意: 这个API处理的图片格式是:图片的名字就是标签的内容,例如一张图片叫做 '2a3w.jpg',2a3w即为标签值.
:return: image_batch,lable_batch
"""
# 遍历获取标签名字 ['2a2s', '2w3e', '4e5r', '3e5r', '4r5g'....]
# file_name = [str(lable).split(".")[0] for lable in os.listdir(path=FLAGS.pic_dir)]
# 构造路径+文件 ['../pic_source/train6000/2a2s.jpg', '', '', '', '', ''.....]
file_list = [os.path.join(FLAGS.pic_dir, labl + ".jpg") for labl in file_name]
# 构造文件队列
image_queue = tf.train.string_input_producer(file_list, shuffle=False)
# 构造阅读器
image_reader = tf.WholeFileReader()
# 读取图片内容
key, value = image_reader.read(image_queue)
# 解码图片数据
image = tf.image.decode_jpeg(value)
image.set_shape([60, 180, 3]) # 必须按照height高 * weight宽 * 通道数 的顺序来写
# 批处理数据 [367, 60, 180, 3]
image_batch = tf.train.batch([image], batch_size=len(file_name),num_threads=1, capacity=len(file_name))
# lable_batch = tf.train.batch(lable_queue, batch_size=len(file_name), num_threads=1, capacity=len(file_name))
print(image_batch)
# print(lable_batch)
return image_batch
def write2tfrecords(image_batch, lable_batch):
"""
将图片内容和标签写入到tfrecords文件当中
:param image_batch:特征值
:param lable_batch:标签值
:return:
"""
# 类型转换
lable_batch = tf.cast(lable_batch,tf.uint8)
print(lable_batch)
# 建立TFRecords存储器
writer = tf.python_io.TFRecordWriter(FLAGS.tfrecords_dir)
# 循环将每一个图片上的数据构造example协议块,序列化后写入
for i in range(FLAGS.image_num):
# 取出第i个图片数据,转换相应类型,图片的特征值要转换成字符串形式
image_string = image_batch[i].eval().tostring()
# 标签值,转换成整型
label_string = lable_batch[i].eval().tostring()
# 构造协议块
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])),
"label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_string]))
}))
writer.write(example.SerializeToString())
# 关闭文件
writer.close()
return None
def generateTFrecord(file_name):
"""
注意: 这个API处理的图片格式是:图片的名字就是标签的内容,例如一张图片叫做 '2a3w.jpg',2a3w即为标签值.
将图片,标签转换成TFrecord格式,其中标签已经由 [2a3w]>>[12,23,45,46]
:param file_name: []
:return:
"""
# 获取验证码中的图片
image_batch = get_image(file_name)
# 获取标签数据
label_batch = dealWithLabel(label_str=file_name)
print(image_batch, label_batch)
with tf.Session() as sess:
coord = tf.train.Coordinator() #构建队列
threads = tf.train.start_queue_runners(sess=sess, coord=coord) #构建线程
# 将图片数据和内容写入到tfrecords文件当中
write2tfrecords(image_batch, label_batch)
coord.request_stop()
coord.join(threads)
print('生成TFRecord文件成功.............................................................')
return None
if __name__ == '__main__':
file_name = [str(lable).split(".")[0] for lable in os.listdir(path=FLAGS.pic_dir)]
generateTFrecord(file_name=file_name)
将验证码图像转换为TFRecord文件
最新推荐文章于 2020-03-18 12:45:35 发布