tensorflow基础学习:字符数字验证码写入tfrecord文件封装成类

今天分享一下我写的一个小小程序,基本可以满足数字+字符类型字符串写入tfrecord文件。还请多多指教!

简单说明:这个是数字+字符4位验证码的tfrecord生成代码,5位,6位的可以自行修改一下,也就一点代码。我因为有点晚了就先不改了,大家加油啦。

  • 先做些准备工作。
  • 所有字符的数据集,用于将字符转化为它的下标数字。
    再存到tfrecord里面。以便于后面读取转化为one-hot编码使用。
import tensorflow as tf
import os
import random
import sys
from PIL import Image
import numpy as np

# 所有字符的数据集,用于将字符转化为它的在列表中的下标数字
# 再存到tfrecord里面。以便于后面读取转化为one-hot编码使用
char_set = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k',
            'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F',
            'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

  • 在这里写成一个类,便于代码复用。
  • 大家可以根据需求稍作修改使用。
注意点:路径记得把反斜杠换了,如 F:\checkimages\,要换为F:/checkimages/,最后面的斜杠别少,F:/checkimages也是不可以的
class Make_Tf_Record(object):
    # 将图片数据转化为tf文件,打包成测试集和训练集

    def __init__(self, captcha_dir, tf_file_save_dir):
        self.char_set = char_set
        # 验证码的存储路径
        self.captcha_dir = captcha_dir
        # 生成的tf文件路径
        self.tf_file_save_dir = tf_file_save_dir

  • 判断保存tfrecord文件的路径里面是否已经存在tfrecord文件
  • 在后面会调用,很简单的几句代码
    def data_exist(self):
        # 判断 record 文件是否存在
       
        for split_name in ['train', 'test']:
            output_filename = os.path.join(self.tf_file_save_dir, split_name + '.tfrecords')
            if not tf.gfile.Exists(output_filename):
                return False
        return True
  • 获取所有验证码图片的具体路径
    def get_all_captcha_filename(self, captcha_dir):
    	 # captcha_dir验证码图片所在的路径
        # 获取所有验证码图片的具体路径
        captcha_filenames = []
        for filename in os.listdir(captcha_dir):
            # 获取文件路径
            path = os.path.join(captcha_dir, filename)
            captcha_filenames.append(path)
        return captcha_filenames
  • 为转化为tf文件做准备的工作,几乎都是固定的写法。
  • 下面这个是为了将图片的像素值以bytes类型存进去,
  • 也可以说是:列表形状,字符串格式。 如"[[123,123,1],[,23,4,534]]"。
  • 这里只是一个简单说明,实际存进去的还是0和1组成的二进制数据。
  • 不然也不叫bytes。读取的时候再解码一下就好,解码tensorflow都有可以调用的函数,不慌。
  • int64_feature: 就是将验证码标签的下标存进去

    def bytes_feature(self, values):
        # 用来存图片像素值
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


	    def int64_feature(self,values):
        # 判断values是否是列表或者元组,如果不是转为列表
        if not isinstance(values, (tuple, list)):
            values = [values]
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

  • 下面就是调用上面的2个函数,接收处理好的image数据和4个标签值,序列化一下,返回一个对象。
  • 返回值调用一下SerializeToString()就可以写进去了
    def image_to_tf_example(self, image_data, label0, label1, label2, label3):
        # 这里默认是4位的字符类型,可以自行修改
        # 传入图片的数据和标签,然后返回Example协议类型的数据
        # 全部以字符类型存进去, 分开字符存是为了多任务训练
        # print(label0, label1, label2, label3)
        
        # 先获取每个字符对应的下标
        label0 = char_set.index(label0)
        label1 = char_set.index(label1)
        label2 = char_set.index(label2)
        label3 = char_set.index(label3)
        
        # return:返回一个Example对象,后来存进去的时候直接序列化一下 .SerializeToString()
        # 这其实是一个类字典的格式,读取的时候就会发现确实是这样
        return tf.train.Example(features=tf.train.Features(feature={
            'image': self.bytes_feature(image_data),
            'label0': self.int64_feature(label0),
            'label1': self.int64_feature(label1),
            'label2': self.int64_feature(label2),
            'label3': self.int64_feature(label3),
        }))
  • 划重点:关键一步,代码长一点点。
  • 看注释好理解
    # 把数据转为TFRecord格式
    def _convert_dataset(self, split_name, filenames):
    	# 断言,其实没啥用,可以直接注释
        assert split_name in ['train', 'test']

        with tf.Session() as sess:
            # 定义tfrecord文件的路径+名字
            output_filename = os.path.join(self.tf_file_save_dir, split_name + '.tfrecords')
            # 开启一个tf文件写入器,取名为tfrecord_writer
            with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                for i, filename in enumerate(filenames):
                    try:
                        sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                        sys.stdout.flush()

                        # 读取图片
                        image_data = Image.open(filename)
                        # 根据模型的结构resize
                        image_data = image_data.resize((224, 224))
                        # 灰度化,并转化为一个保存着像素值的多维数组
                        image_data = np.array(image_data.convert('L'))
                        # 将图片转化为bytes
                        image_data = image_data.tobytes()

                        # 获取label,获取路径分割的数组的最后一个也就是 abcd.jpg
                        labels = filename.split('/')[-1][0:4]
                        # 获取前面4个标签值
                        num_labels = []
                        for j in range(4):
                            num_labels.append(labels[j])

                        # 调用上面的函数,可以往回看看,生成protocol数据类型
                        example = self.image_to_tf_example(image_data, num_labels[0], num_labels[1], num_labels[2],
                                                           num_labels[3])
							# 调用write方法,直接写入一个图片的文件。
							# SerializeToString在上面也提到啦,其实基本都是这么写,想多了解可以看一下函数介绍
                        tfrecord_writer.write(example.SerializeToString())

                    except IOError as e:
                        print('Could not read:', filename)
                        print('Error:', e)
                        print('Skip it\n')
        sys.stdout.write('\n')
        sys.stdout.flush()
  • 最后的一个主函数,直接创建对象后调用这个函数就可以生成tf文件
  • 其实打乱的步骤可以去掉也是可以的,
  • 原因:get_all_captcha_filename中使用的是os.listdir(),这个函数返回的文件名称列表就是乱的
    def start(self, test_num):
        # 判断tfrecord文件是否存在
        if self.data_exist():
            print('tfcecord文件已存在')
        else:
            # 获得所有图片
            captcha_filenames = self.get_all_captcha_filename(self.captcha_dir)

            # 把数据切分为训练集和测试集,并打乱
            # 随机种子设置为0
            random.seed(0)
            random.shuffle(captcha_filenames)
            training_filenames = captcha_filenames[test_num:]
            testing_filenames = captcha_filenames[:test_num]
			
            # 数据转换
            self._convert_dataset('train', training_filenames)
            self._convert_dataset('test', testing_filenames)

            print('生成tfcecord文件')
            

小结:这只是我写的一些自己以后可能会用到的东西顺便分享一下,喜欢的化可以关注一下,以后会不断得分享python各个方向的文章。爬虫,数据分析,web,数据挖掘。大家早透啦!!!

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ziaiyu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值