在深度学习训练中，如果训练数据量较小，可以使用 feed_dict 方式喂数据；但如果数据量较大，一般采用 TensorFlow 自己的数据格式 TFRecord。这样做一是可以避免内存不足以存放全部训练数据，二是可以加快数据读取，因为 TFRecord 格式的数据可以通过队列异步读取。
下面以 PASCAL VOC 数据集格式为例，将标注和图像转换为 TFRecord：
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import

import glob
import math
import os
import sys

try:
    import xml.etree.cElementTree as ET
except ImportError:
    # xml.etree.cElementTree was removed in Python 3.9; since Python 3.3
    # ElementTree uses the same C accelerator automatically.
    import xml.etree.ElementTree as ET

sys.path.append('../../')

import cv2
import numpy as np
import tensorflow as tf
# Command-line options for the VOC -> TFRecord conversion (TF1 tf.app.flags);
# override any of them on the command line, e.g. --VOC_dir=/data/voc/.
# NOTE(review): the VOC_dir default is a machine-specific Windows path; pass
# --VOC_dir explicitly on any other machine.
tf.app.flags.DEFINE_string('VOC_dir', 'C:\\Elag\\data\\tianchi\\ICPR\\test\\', 'Voc dir')
# Sub-directory of VOC_dir that holds the *.xml annotation files.
tf.app.flags.DEFINE_string('xml_dir', 'Annotations', 'xml dir')
# Sub-directory of VOC_dir that holds the image files.
tf.app.flags.DEFINE_string('image_dir', 'image', 'image dir')
# Together with `dataset`, names the output record file.
tf.app.flags.DEFINE_string('save_name', 'train', 'save name')
# Directory the output record file is written to (created if missing).
tf.app.flags.DEFINE_string('save_dir', './', 'save dir')
# Image extension (including the dot) appended to each xml base name.
tf.app.flags.DEFINE_string('img_format', '.jpg', 'format of image')
# Dataset tag used as the prefix of the output record file name.
tf.app.flags.DEFINE_string('dataset', 'pascal', 'dataset')
FLAGS = tf.app.flags.FLAGS
# Class-name -> integer-label maps. Label 0 is reserved for background in both.

# The 20 PASCAL VOC object classes.
PASCAL_NAME_LABEL_MAP = {
    'back_ground': 0,
    'aeroplane': 1,
    'bicycle': 2,
    'bird': 3,
    'boat': 4,
    'bottle': 5,
    'bus': 6,
    'car': 7,
    'cat': 8,
    'chair': 9,
    'cow': 10,
    'diningtable': 11,
    'dog': 12,
    'horse': 13,
    'motorbike': 14,
    'person': 15,
    'pottedplant': 16,
    'sheep': 17,
    'sofa': 18,
    'train': 19,
    'tvmonitor': 20
}

# Single-class map, presumably for the text-detection data the default
# VOC_dir points at (verify for your dataset).
TEXT_NAME_LABEL_MAP = {
    'back_ground': 0,
    'text': 1
}

# Map actually used when parsing annotations. Point this at
# PASCAL_NAME_LABEL_MAP to convert the 20-class VOC annotations instead.
NAME_LABEL_MAP = TEXT_NAME_LABEL_MAP
def _int64_feature(value):
    """Return a tf.train.Feature whose int64_list holds the single value `value`."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
def _bytes_feature(value):
    """Return a tf.train.Feature whose bytes_list holds the single value `value`."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def view_bar(message, num, total, width=40):
    """Write a single-line progress bar to stdout, redrawing it in place.

    The line starts with a carriage return, so repeated calls overwrite the
    previous bar. The line contains the message, a bar of '>' characters,
    the percentage (rounded up), and 'num/total'.

    :param message: label shown before the bar
    :param num: number of items completed so far
    :param total: total number of items. A non-positive total is drawn as a
        full (100%) bar instead of raising ZeroDivisionError.
    :param width: length of the bar in characters (default 40)
    """
    rate = num / total if total > 0 else 1.0
    filled = int(rate * width)
    percent = math.ceil(rate * 100)
    line = '\r%s:[%s%s]%d%%\t%d/%d' % (message, ">" * filled, " " * (width - filled), percent, num, total,)
    sys.stdout.write(line)
    sys.stdout.flush()
def mkdir(path):
    """Create directory `path` (and any missing parents) if it does not exist.

    Creating first and tolerating an existing directory avoids the race
    between an exists() check and makedirs(). This form works on Python 2,
    where `exist_ok` is unavailable.

    :param path: directory to create
    :raises OSError: if the directory cannot be created, or if `path`
        already exists but is not a directory
    """
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise
def read_xml_gtbox_and_label(xml_path, label_map=None):
    """
    Parse one VOC-style annotation xml.

    :param xml_path: the path of voc xml
    :param label_map: optional dict mapping class name -> integer label;
        defaults to the module-level NAME_LABEL_MAP
    :return: tuple (img_height, img_width, gtbox_label).
        img_height / img_width are the ints read from <size> (None if absent).
        gtbox_label is an np.int32 array with one row per <bndbox> found
        directly under an <object>. Each row holds that bndbox's child values
        in document order, truncated to int, followed by the integer label.
        A plain xmin/ymin/xmax/ymax box therefore gives 5 columns. The
        [x1, y1, ..., x4, y4, label] 9-column layout arises only when the
        bndbox carries 8 coordinates (NOTE(review): confirm for your data).
    :raises KeyError: if an object's class name is not in label_map
    :raises ValueError: if an object has a <bndbox> but no usable <name>
    """
    if label_map is None:
        label_map = NAME_LABEL_MAP
    root = ET.parse(xml_path).getroot()
    img_width = None
    img_height = None
    box_list = []
    for child_of_root in root:
        if child_of_root.tag == 'filename':
            print(child_of_root.text)
        elif child_of_root.tag == 'size':
            for child_item in child_of_root:
                if child_item.tag == 'width':
                    img_width = int(child_item.text)
                elif child_item.tag == 'height':
                    img_height = int(child_item.text)
        elif child_of_root.tag == 'object':
            # Look the children up by tag so <name> may come before or
            # after <bndbox>; XML does not guarantee element order.
            name_node = child_of_root.find('name')
            label = None
            if name_node is not None and name_node.text is not None:
                label = label_map[name_node.text.strip()]
            for bndbox in child_of_root.findall('bndbox'):
                if label is None:
                    # An explicit raise instead of assert, which -O strips.
                    raise ValueError('label is none, error')
                tmp_box = [int(float(node.text)) for node in bndbox]
                tmp_box.append(label)
                box_list.append(tmp_box)
    gtbox_label = np.array(box_list, dtype=np.int32)
    return img_height, img_width, gtbox_label
def convert_pascal_to_tfrecord():
    """Convert a VOC-style directory (xml annotations + images) into one TFRecord file.

    For every ``*.xml`` under ``FLAGS.VOC_dir/FLAGS.xml_dir``, the image with
    the same base name plus ``FLAGS.img_format`` is loaded from
    ``FLAGS.VOC_dir/FLAGS.image_dir``. One tf.train.Example is written per
    pair to ``FLAGS.save_dir/<dataset>_<save_name>.tfrecord``.

    Each Example stores:
      - img_name: image file name (bytes)
      - img_height / img_width: shape of the decoded image array
      - img: raw uint8 pixel bytes as returned by cv2.imread
      - gtboxes_and_label: raw bytes of the int32 box array from the xml
      - num_objects: number of box rows

    Pairs whose image is missing, cannot be decoded, or whose features cannot
    be built are reported and skipped.
    """
    xml_path = os.path.join(FLAGS.VOC_dir, FLAGS.xml_dir)
    image_path = os.path.join(FLAGS.VOC_dir, FLAGS.image_dir)
    save_path = os.path.join(FLAGS.save_dir, FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord')
    mkdir(FLAGS.save_dir)

    # Glob once and reuse the list for the progress total.
    xml_files = glob.glob(os.path.join(xml_path, '*.xml'))
    total = len(xml_files)

    writer = tf.python_io.TFRecordWriter(path=save_path)
    try:
        for count, xml_file in enumerate(xml_files):
            # to avoid path error in different development platform
            xml_file = xml_file.replace('\\', '/')
            # splitext keeps base names that contain extra dots intact
            img_name = os.path.splitext(os.path.basename(xml_file))[0] + FLAGS.img_format
            img_path = os.path.join(image_path, img_name)
            if not os.path.exists(img_path):
                print('{} is not exist!'.format(img_path))
                continue

            img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml_file)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None instead of raising on unreadable files.
                print(img_path + " error")
                continue

            # Readers reshape the raw pixel bytes with the stored height/width,
            # so store the decoded array's size rather than trusting the xml.
            real_height, real_width = img.shape[:2]
            if (img_height, img_width) != (real_height, real_width):
                print('{}: xml size {}x{} differs from image size {}x{}; '
                      'storing the image size'.format(img_path, img_width, img_height,
                                                      real_width, real_height))
            img_height, img_width = real_height, real_width

            try:
                feature = tf.train.Features(feature={
                    'img_name': _bytes_feature(img_name.encode()),
                    'img_height': _int64_feature(img_height),
                    'img_width': _int64_feature(img_width),
                    'img': _bytes_feature(img.tobytes()),
                    'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
                    'num_objects': _int64_feature(gtbox_label.shape[0])
                })
            except (TypeError, ValueError):
                # Skip the sample. Falling through would write the previous
                # sample's `feature` (or raise NameError on the first one).
                print(img_path + " error")
                continue

            example = tf.train.Example(features=feature)
            writer.write(example.SerializeToString())
            view_bar('Conversion progress', count + 1, total)
    finally:
        # Flush and close so the record file is complete even if an error escapes.
        writer.close()
    print('\nConversion is complete!')


if __name__ == '__main__':
    # Script entry point; FLAGS are read from the command line on first access.
    convert_pascal_to_tfrecord()
读取tfrecord数据
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import os
def read_single_example_and_decode(filename_queue, num_columns=9):
    """Read one serialized tf.train.Example from `filename_queue` and decode it.

    :param filename_queue: queue of TFRecord file names (e.g. from
        tf.train.string_input_producer)
    :param num_columns: number of int32 values per row of 'gtboxes_and_label'.
        The default 9 matches [x1, y1, ..., x4, y4, label]. It must equal the
        row length the writer stored; an xmin/ymin/xmax/ymax box plus label
        is 5.
    :return: (img_name, img, gtboxes_and_label, num_objects), where
        - img is a uint8 tensor of shape [img_height, img_width, 3]
        - gtboxes_and_label is an int32 tensor of shape [?, num_columns]
        - img_name is a scalar string tensor
        - num_objects is a scalar int32 tensor
    """
    # For ZLIB-compressed records construct the reader with
    # options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB).
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized=serialized_example,
        features={
            'img_name': tf.FixedLenFeature([], tf.string),
            'img_height': tf.FixedLenFeature([], tf.int64),
            'img_width': tf.FixedLenFeature([], tf.int64),
            'img': tf.FixedLenFeature([], tf.string),
            'gtboxes_and_label': tf.FixedLenFeature([], tf.string),
            'num_objects': tf.FixedLenFeature([], tf.int64)
        }
    )
    img_name = features['img_name']
    img_height = tf.cast(features['img_height'], tf.int32)
    img_width = tf.cast(features['img_width'], tf.int32)
    img = tf.decode_raw(features['img'], tf.uint8)
    img = tf.reshape(img, shape=[img_height, img_width, 3])
    gtboxes_and_label = tf.decode_raw(features['gtboxes_and_label'], tf.int32)
    gtboxes_and_label = tf.reshape(gtboxes_and_label, [-1, num_columns])
    num_objects = tf.cast(features['num_objects'], tf.int32)
    return img_name, img, gtboxes_and_label, num_objects
def read_and_prepocess_single_img(filename_queue, is_training):
    """Decode one example and mean-center its image.

    :param filename_queue: queue of TFRecord file names
    :param is_training: accepted for interface compatibility; not used here
    :return: (img_name, img, gtboxes_and_label, num_objects), where img is
        float32 with a fixed per-channel mean subtracted
    """
    decoded = read_single_example_and_decode(filename_queue)
    img_name, raw_img, gtboxes_and_label, num_objects = decoded
    float_img = tf.cast(raw_img, tf.float32)
    # Per-channel means. NOTE(review): the [B, G, R]-ordered values assume the
    # records hold cv2.imread (BGR) pixels, as the conversion script writes.
    centered_img = float_img - tf.constant([103.939, 116.779, 123.68])
    return img_name, centered_img, gtboxes_and_label, num_objects
def next_batch(dataset_name, batch_size, is_training, tfrecord_dir='../data/tfrecords'):
    """Build a queue-based TF1 input pipeline that yields padded batches.

    :param dataset_name: one of 'ship', 'spacenet', 'text', 'pascal', 'coco'.
        Selects record files named ``<dataset_name>_train*`` or
        ``<dataset_name>_test*``.
    :param batch_size: number of examples per batch
    :param is_training: True selects the ``_train*`` files, False the
        ``_test*`` files. Also forwarded to read_and_prepocess_single_img.
    :param tfrecord_dir: directory searched for the record files (the default
        is the previously hard-coded '../data/tfrecords')
    :return: (img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch).
        With dynamic_pad=True, images and box arrays are padded to the largest
        shape in each batch.
    :raises ValueError: if dataset_name is not one of the supported names

    NOTE: tf.train.match_filenames_once creates a local variable, so run
    tf.local_variables_initializer() and start the queue runners before
    evaluating the returned tensors.
    """
    valid_names = ('ship', 'spacenet', 'text', 'pascal', 'coco')
    if dataset_name not in valid_names:
        raise ValueError('dataSet name must be one of: {}'.format(', '.join(valid_names)))
    split_suffix = '_train*' if is_training else '_test*'
    pattern = os.path.join(tfrecord_dir, dataset_name + split_suffix)
    print('tfrecord path is -->', os.path.abspath(pattern))
    filename_tensorlist = tf.train.match_filenames_once(pattern)
    filename_queue = tf.train.string_input_producer(filename_tensorlist)
    img_name, img, gtboxes_and_label, num_obs = read_and_prepocess_single_img(filename_queue,
                                                                              is_training=is_training)
    img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch = \
        tf.train.batch(
            [img_name, img, gtboxes_and_label, num_obs],
            batch_size=batch_size,
            capacity=100,
            num_threads=16,
            dynamic_pad=True)
    return img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch