tfrecord格式是tensorflow官方推荐的数据格式,把数据、标签进行统一的存储
tfrecord文件包含了tf.train.Example 协议缓冲区(protocol buffer,协议缓冲区包含了特征 Features), 能让tensorflow更好的利用内存。
把某个文件夹的图片和标签存入同一个tfrecord文件,代码如下:
def write(input_file, output_file):
writer = tf.python_io.TFRecordWriter(output_file) #定义writer,传入目标文件路径
path = input_file
file_names = [f for f in os.listdir(path) if f.endswith('.jpg')] #获取待存文件路径
for file_name in file_names:
img = cv2.imread(path + file_name)
raw_img = img.tobytes() #需要把图片文件转化成bytes形式(二进制比特流)
# 把数据合并成feature,注意这里的"value="后面一定要是一个"[]"形式的列表,否则读取的时候会出现can't parse的情况
features = tf.train.Features(feature={'img_name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[file_name])),
'raw_img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw_img]))})
#把features存入example
example = tf.train.Example(features=features)
#example序列化,并写入文件
writer.write(example.SerializeToString())
writer.close()
input_file = 'samples/'
output_file = 'samples.tfrecords'
write(input_file, output_file)
print 'Write tfrecords: %s done' %output_file
基本步骤:
- 读取待存文件内容,转化为bytes形式
- 数据合并成tf.train.Features(类似dict形式)
- 把features存入一个tf.train.Example
- 把example序列化,并写入文件
写入文件的实际上是若干个example。
其中,tf.train.Features的bytes_list支持的类型有三种:tf.train.ByteList、tf.train.FloatList、tf.train.Int64List,形式如下:
tf.train.Features(
feature={
'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),
#'label': tf.train.Feature(float_list = tf.train.FloatList(value=[i])),
'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
}))
读取文件,代码如下:
def read_and_decode(file_name):
filename_queue = tf.train.string_input_producer([file_name])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={"img_name": tf.FixedLenFeature([], tf.string),
"raw_img": tf.FixedLenFeature([], tf.string)})
img_name = features["img_name"]
image = tf.decode_raw(features['raw_img'], tf.uint8)
image = tf.reshape(image, [256, 256, 3])
return img_name, image
path = 'samples.tfrecords'
with tf.Session() as sess:
img_name, img = read_and_decode(path)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
for i in range(n):
name, image = sess.run([img_name, img])
print name
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
except tf.errors.OutOfRangeError:
print 'Done training -- epoch limit reached'
finally:
coord.request_stop()
coord.join(threads)
上述读取过程结合了文件队列,基本过程如下:
- 定义文件名队列ft.train.string_input_producer
- 定义tf.TFRecordReader
- 读取序列化的example
- 调用tf.tf.parse_single_example解析example,得到features
- 从features获取具体的数据,如果是图像,进行解码和reshape(还可以进行相关的预处理)
上述1~5步是读取一个example的“graph”,在实际使用时,先定义好graph,然后start_queue_runners(注意先后顺序,否则进程将阻塞),再根据需要,循环读取数据。