TFRecords 作为TensorFlow标准支持格式,将所有信息(包括图片信息)写入到一个tfrecords文件中,便于管理数据。TFRecords 是二进制文件,有特定的写入和读取方式。
需要将image目录(E:/anno/image)中的图片文件和anno.txt标注文件转换为tfrecords格式。
anno.txt中的标注数据如下,每行以空格分隔,依次为:图片路径、类别标签、边界框坐标xmin ymin xmax ymax:
E:/anno/image/Drivingdecordr_01.jpg 1 648 512 793 549
E:/anno/image/Drivingdecordr_02.jpg 1 1330 446 1388 477
E:/anno/image/Drivingdecordr_03.jpg 1 26 509 95 543
E:/anno/image/Drivingdecordr_04.jpg 1 1430 437 1503 466
E:/anno/image/Drivingdecordr_05.jpg 1 1582 521 1714 567
E:/anno/image/Drivingdecordr_06.jpg 1 888 502 1053 547
E:/anno/image/Drivingdecordr_07.jpg 1 1419 453 1493 478
E:/anno/image/Drivingdecordr_08.jpg 1 31 495 102 530
E:/anno/image/Drivingdecordr_09.jpg 1 862 471 989 508
以下代码中,get_data()解析anno.txt标注文件,trans2tfrecords()将图片及其标注写入tfrecords文件,read_tfrecords()在训练模型时读取tfrecords格式数据。
# -*- coding: utf-8 -*-
import tensorflow as tf
import cv2
import sys
import random
def get_data(data_dir):
    """Parse an annotation file into a list of example dicts.

    Each non-blank line must hold whitespace-separated fields:
    ``<image path> <label> <xmin> <ymin> <xmax> <ymax>``.

    Args:
        data_dir: Path to the annotation text file (e.g. ``anno.txt``).

    Returns:
        A list, in file order, of dicts with keys ``'filename'`` (str),
        ``'label'`` (int) and ``'box'`` (dict with float values under
        ``'xmin'``, ``'ymin'``, ``'xmax'``, ``'ymax'``).

    Raises:
        IndexError / ValueError: If a non-blank line has too few fields or a
            field cannot be converted to the expected number type.
    """
    dataset = []
    # Context manager guarantees the handle is closed even if parsing fails.
    with open(data_dir, 'r') as anno_file:
        for line in anno_file:
            # split() with no argument tolerates repeated spaces/tabs.
            info = line.split()
            if not info:
                # Skip blank lines such as a trailing newline at end of file.
                continue
            box = {
                'xmin': float(info[2]),
                'ymin': float(info[3]),
                'xmax': float(info[4]),
                'ymax': float(info[5]),
            }
            dataset.append({
                'filename': info[0],
                'label': int(info[1]),
                'box': box,
            })
    return dataset
def _int64_feature(value):
    """Build a ``tf.train.Feature`` holding int64 values.

    Args:
        value: A single integer, or a list of integers. Any non-list value is
            wrapped in a one-element list.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` carries the value(s).
    """
    values = value if isinstance(value, list) else [value]
    int_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int_list)
def _float_feature(value):
    """Build a ``tf.train.Feature`` holding float values.

    Args:
        value: A single number, or a list of numbers. Any non-list value is
            wrapped in a one-element list.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` carries the value(s).
    """
    values = value if isinstance(value, list) else [value]
    float_list = tf.train.FloatList(value=values)
    return tf.train.Feature(float_list=float_list)
def _bytes_feature(value):
    """Build a ``tf.train.Feature`` holding byte strings.

    Args:
        value: A single bytes object, or a list of them. Any non-list value
            is wrapped in a one-element list.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` carries the value(s).
    """
    values = value if isinstance(value, list) else [value]
    bytes_list = tf.train.BytesList(value=values)
    return tf.train.Feature(bytes_list=bytes_list)
def trans2tfrecords(tf_filename, dataset):
    """Serialize every annotated image in ``dataset`` into one TFRecord file.

    Each record is a ``tf.train.Example`` whose features are:
      * ``'image'``: raw pixel bytes returned by ``extract`` (no shape stored),
      * ``'label'``: the integer class label,
      * ``'roi'``:   the box as four floats ``[xmin, ymin, xmax, ymax]``.

    Args:
        tf_filename: Output path of the ``.tfrecord`` file.
        dataset: List of dicts with ``'filename'``, ``'label'`` and ``'box'``
            keys, as produced by ``get_data``.
    """
    with tf.python_io.TFRecordWriter(tf_filename) as writer:
        for position, record in enumerate(dataset, start=1):
            sys.stdout.write('\r>> Converting image %d/%d\n' % (position, len(dataset)))
            image_bytes = extract(record['filename'])
            label = record['label']
            bbox = record['box']
            roi = [bbox[key] for key in ('xmin', 'ymin', 'xmax', 'ymax')]
            # An Example wraps a Features message, which maps names to Feature.
            features = tf.train.Features(feature={
                'image': _bytes_feature(image_bytes),
                'label': _int64_feature(label),
                'roi': _float_feature(roi),
            })
            serialized = tf.train.Example(features=features).SerializeToString()
            writer.write(serialized)
def extract(filename):
    """Load an image with OpenCV and return its raw pixel buffer as bytes.

    Args:
        filename: Path of the image file, passed to ``cv2.imread`` with its
            default flags.

    Returns:
        The decoded image array serialized with ``ndarray.tobytes()``. No
        height/width/channel information is included, so a reader must know
        the image shape to rebuild the array.

    Raises:
        IOError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(filename)
    # cv2.imread reports failure by returning None instead of raising, which
    # previously surfaced as an opaque AttributeError further down.
    if image is None:
        raise IOError('Could not read image: %s' % filename)
    # tobytes() is the supported replacement for the deprecated tostring().
    return image.tobytes()
def read_tfrecords(tfrecord_file, image_shape=(128, 128, 3)):
    """Build a TF1 queue-based pipeline that parses one example from a file.

    Args:
        tfrecord_file: Path of a TFRecord file whose examples contain
            ``'image'`` (raw bytes), ``'label'`` (int64) and ``'roi'``
            (4 floats).
        image_shape: ``(height, width, channels)`` of the stored images. The
            ``'image'`` feature holds raw uint8 pixels with no shape
            information, so this must match the serialized images: its
            product must equal the byte length of each ``'image'`` feature,
            otherwise the reshape fails when the graph runs. Defaults to
            ``(128, 128, 3)``.

    Returns:
        A tuple ``(image, label, roi)`` of tensors:
          * ``image``: float32 with shape ``image_shape``, scaled as
            ``(pixel - 127.5) / 128``;
          * ``label``: the class label cast to float32;
          * ``roi``: the float32 box of length 4.
    """
    # NOTE: shuffle=True shuffles the list of *filenames* each epoch. With a
    # single file it does not shuffle the records inside that file.
    filename_queue = tf.train.string_input_producer([tfrecord_file], shuffle=True)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    image_features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'roi': tf.FixedLenFeature([4], tf.float32),
        })
    # Reinterpret the raw byte string as a flat uint8 pixel vector.
    image = tf.decode_raw(image_features['image'], tf.uint8)
    image = tf.reshape(image, list(image_shape))
    # Map pixel values in [0, 255] to roughly [-1, 1].
    image = (tf.cast(image, tf.float32) - 127.5) / 128
    label = tf.cast(image_features['label'], tf.float32)
    roi = tf.cast(image_features['roi'], tf.float32)
    return image, label, roi
if __name__ == '__main__':
    # Annotation file to read and directory that receives the TFRecord.
    data_dir = 'E:/anno/anno.txt'
    output_dir = 'E:/anno'
    tf_filename = '%s/train_data.tfrecord' % output_dir
    # A real boolean: the former string 'True' made any non-empty value,
    # including 'False', enable shuffling.
    shuffling = True
    dataset = get_data(data_dir)
    if shuffling:
        # Randomize record order in the output file.
        random.shuffle(dataset)
    trans2tfrecords(tf_filename, dataset)
    print('Finished converting dataset!')
    # This only builds the TF1 input graph. The tensors yield data only when
    # evaluated in a tf.Session after queue runners are started.
    # NOTE(review): read_tfrecords reshapes images to a fixed shape, but the
    # sample boxes in anno.txt reach x = 1714, so the source images are far
    # larger than 128x128. Confirm the expected shape matches the images.
    image, label, roi = read_tfrecords(tf_filename)