今天分享一下我写的一个小小程序,基本可以满足数字+字符类型字符串写入tfrecord文件。还请多多指教!
简单说明:这个是数字+字符4位验证码的tfrecord生成代码,5位,6位的可以自行修改一下,也就一点代码。我因为有点晚了就先不改了,大家加油啦。
先做些准备工作。 所有字符的数据集,用于将字符转化为它的下标数字。 再存到tfrecord里面。以便于后面读取转化为one-hot编码使用。
import tensorflow as tf
import os
import random
import sys
from PIL import Image
import numpy as np
char_set = [ '0' , '1' , '2' , '3' , '4' , '5' , '6' , '7' , '8' , '9' , 'a' , 'b' , 'c' , 'd' , 'e' , 'f' , 'g' , 'h' , 'i' , 'j' , 'k' ,
'l' , 'm' , 'n' , 'o' , 'p' , 'q' , 'r' , 's' , 't' , 'u' , 'v' , 'w' , 'x' , 'y' , 'z' , 'A' , 'B' , 'C' , 'D' , 'E' , 'F' ,
'G' , 'H' , 'I' , 'J' , 'K' , 'L' , 'M' , 'N' , 'O' , 'P' , 'Q' , 'R' , 'S' , 'T' , 'U' , 'V' , 'W' , 'X' , 'Y' , 'Z' ]
在这里写成一个类,便于代码复用。 大家可以根据需求稍作修改使用。
注意点:路径记得把反斜杠换了,如 F:\checkimages\,要换为F:/checkimages/,最后面的斜杠别少,F:/checkimages也是不可以的
class Make_Tf_Record ( object ) :
def __init__ ( self, captcha_dir, tf_file_save_dir) :
self. char_set = char_set
self. captcha_dir = captcha_dir
self. tf_file_save_dir = tf_file_save_dir
判断保存tfrecord文件的路径里面是否已经存在tfrecord文件 在后面会调用,很简单的几句代码
def data_exist ( self) :
for split_name in [ 'train' , 'test' ] :
output_filename = os. path. join( self. tf_file_save_dir, split_name + '.tfrecords' )
if not tf. gfile. Exists( output_filename) :
return False
return True
def get_all_captcha_filename ( self, captcha_dir) :
captcha_filenames = [ ]
for filename in os. listdir( captcha_dir) :
path = os. path. join( captcha_dir, filename)
captcha_filenames. append( path)
return captcha_filenames
为转化为tf文件做准备的工作,几乎都是固定的写法。 下面这个是为了将图片的像素值以bytes类型存进去, 也可以说是:列表形状,字符串格式。 如"[[123,123,1],[,23,4,534]]"。 这里只是一个简单说明,实际存进去的还是0和1组成的二进制数据。 不然也不叫bytes。读取的时候再解码一下就好,解码tensorflow都有可以调用的函数,不慌。 int64_feature: 就是将验证码标签的下标存进去
def bytes_feature ( self, values) :
return tf. train. Feature( bytes_list= tf. train. BytesList( value= [ values] ) )
def int64_feature ( self, values) :
if not isinstance ( values, ( tuple , list ) ) :
values = [ values]
return tf. train. Feature( int64_list= tf. train. Int64List( value= values) )
下面就是调用上面的2个函数,接收处理好的image数据和4个标签值,序列化一下,返回一个对象。 返回值调用一下SerializeToString()就可以写进去了
def image_to_tf_example ( self, image_data, label0, label1, label2, label3) :
label0 = char_set. index( label0)
label1 = char_set. index( label1)
label2 = char_set. index( label2)
label3 = char_set. index( label3)
return tf. train. Example( features= tf. train. Features( feature= {
'image' : self. bytes_feature( image_data) ,
'label0' : self. int64_feature( label0) ,
'label1' : self. int64_feature( label1) ,
'label2' : self. int64_feature( label2) ,
'label3' : self. int64_feature( label3) ,
} ) )
def _convert_dataset ( self, split_name, filenames) :
assert split_name in [ 'train' , 'test' ]
with tf. Session( ) as sess:
output_filename = os. path. join( self. tf_file_save_dir, split_name + '.tfrecords' )
with tf. python_io. TFRecordWriter( output_filename) as tfrecord_writer:
for i, filename in enumerate ( filenames) :
try :
sys. stdout. write( '\r>> Converting image %d/%d' % ( i + 1 , len ( filenames) ) )
sys. stdout. flush( )
image_data = Image. open ( filename)
image_data = image_data. resize( ( 224 , 224 ) )
image_data = np. array( image_data. convert( 'L' ) )
image_data = image_data. tobytes( )
labels = filename. split( '/' ) [ - 1 ] [ 0 : 4 ]
num_labels = [ ]
for j in range ( 4 ) :
num_labels. append( labels[ j] )
example = self. image_to_tf_example( image_data, num_labels[ 0 ] , num_labels[ 1 ] , num_labels[ 2 ] ,
num_labels[ 3 ] )
tfrecord_writer. write( example. SerializeToString( ) )
except IOError as e:
print ( 'Could not read:' , filename)
print ( 'Error:' , e)
print ( 'Skip it\n' )
sys. stdout. write( '\n' )
sys. stdout. flush( )
最后的一个主函数,直接创建对象后调用这个函数就可以生成tf文件 其实打乱的步骤可以去掉也是可以的, 原因:get_all_captcha_filename中使用的是os.listdir(),这个函数返回的文件名称列表就是乱的
def start ( self, test_num) :
if self. data_exist( ) :
print ( 'tfcecord文件已存在' )
else :
captcha_filenames = self. get_all_captcha_filename( self. captcha_dir)
random. seed( 0 )
random. shuffle( captcha_filenames)
training_filenames = captcha_filenames[ test_num: ]
testing_filenames = captcha_filenames[ : test_num]
self. _convert_dataset( 'train' , training_filenames)
self. _convert_dataset( 'test' , testing_filenames)
print ( '生成tfcecord文件' )
小结:这只是我写的一些自己以后可能会用到的东西顺便分享一下,喜欢的化可以关注一下,以后会不断得分享python各个方向的文章。爬虫,数据分析,web,数据挖掘。大家早透啦!!!