TFrecord生成方法 ——更快速的方法
-
TFRecord格式
Tensorflow支持的一种数据格式,其内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。
-
TFRecord变量类型有三种:
# 用于分类标签存储 tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 用于图片存储 tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 用于浮点型标签存储 tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
注意:value的值为list类型
-
第一步:将数据转化为tf.train.Feature格式。
- int64和float可直接转化,例如:
tf.train.Feature(int64_list=tf.train.Int64List(value=[1])) tf.train.Feature(float_list=tf.train.FloatList(value=[1.0,2.0,3.0]))
- 转化为bytes时,有两种方法:
-
利用tensorflow转化。tensorflow读取的是最原始的没有经过解码的图像,而通常我们可视化的图像都是解码的形式,所以要对读取后的图像进行解码,然后再进行resize等操作,最后将图片转换成string类型。
tf.string:Variable length byte arrays. Each element of a Tensor is a byte array。# 读取最原始的没有经过解码的图像,byte类型 with tf.gfile.GFile(image_path, 'rb') as fid: image_raw_data = fid.read() # 解码,可得到像素值 image = tf.image.decode_jpeg(image_raw_data) # resize image = tf.image.resize(image, [resize, resize]) with tf.Session() as sess: image = sess.run(image) # 将图片转换成string image_data = image.tostring() tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
-
一种更快的方法:
不利用tensorflow,直接利用PIL或CV库操作图像,然后将图像转为bytes。image = cv2.resize(image,(resize, resize)) image_data = image.tobytes()
- int64和float可直接转化,例如:
-
第二步:创建Example对象,并将Feature一一对应填充进去
tf_example = tf.train.Example(features=tf.train.Features(feature={ 'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])), 'float_label': tf.train.Feature( float_list=tf.train.FloatList(value=[1.0,2.0,3.0])), 'label': tf.train.Feature( float_list=tf.train.FloatList(value=[0])) }))
-
第三步:存入tfrecord
Writern = tf.python_io.TFRecordWriter(record_path) # 调用 SerializetoString() 进行序列化 Writern.write(tf_example.SerializeToString())