目录
1、为什么使用TFRecord格式的文件
在深度学习模型训练时,通常我们使用的数据量往往都会很大,这么多数据会占用很大的磁盘空间,并且在被一个个读取的时候会很慢、很繁琐,占用大量内存空间(有的大型数据不足以一次性加载)。TFRecord格式文件内部使用了“Protocol Buffer”二进制数据编码方案,它只需占用一个内存块,只需要一次性加载一个二进制文件。当数据比较大的时候,我们可以把数据制作成多个TFRecord文件,来提高处理效率。
2、生成TFRecord文件的主要流程
1)使用 tf.python_io.TFRecordWriter 函数创建一个TFRecord文件的写入器(writer),后续通过它把序列化后的样本写入文件
tf_filename = _get_output_filename(output_dir,name,fidx) with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer
这里的output_dir参数是生成.tfrecord文件的路径
2)使用tf.train.Example函数生成example
example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': int64_feature(shape[0]), 'image/width': int64_feature(shape[1]), 'image/channels': int64_feature(shape[2]), 'image/shape': int64_feature(shape), 'image/object/bbox/xmin': float_feature(xmin), 'image/object/bbox/xmax': float_feature(xmax), 'image/object/bbox/ymin': float_feature(ymin), 'image/object/bbox/ymax': float_feature(ymax), 'image/filename': int64_feature(xml_pic_name), 'image/object/bbox/label': int64_feature(labels), 'image/object/bbox/label_text': bytes_feature(labels_text), 'image/object/bbox/difficult': int64_feature(difficult), 'image/object/bbox/truncated': int64_feature(truncated), 'image/format': bytes_feature(image_format), 'image/encoded': bytes_feature(image_data)}))
上面写入example中的数据格式主要有三种,分别是 BytesList, FloatList,Int64List。其定义如下代码段:
def int64_feature(value): if not isinstance(value,list): value = [value] return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) def float_feature(value): if not isinstance(value,list): value = [value] return tf.train.Feature(float_list=tf.train.FloatList(value=value)) def bytes_feature(value): if not isinstance(value,list): value = [value] return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
3)通过 tfrecord_writer.write(example.SerializeToString()) 语句将数据写入.tfrecord文件
3、SSD算法中生成TFRecord格式数据主要代码
import tensorflow as tf
import os
import sys
import random
import xml.etree.ElementTree as ET
from pascalvoc_common import VOC_LABELS
from dataset_utils import int64_feature,float_feature,bytes_feature
# Sub-directory (under the dataset root) holding the .xml annotation files.
DIRECTORY_ANNOTATIONS = 'Annotations/'
# Sub-directory holding the .jpg images.
DIRECTORY_IMAGES = '/JPEGImages/'
# Fixed seed so an optional shuffle of the file list is reproducible.
RANDOM_SEED = 4242
# Number of images packed into each .tfrecord shard.
SAMPLES_PER_FILES = 1
def _get_output_filename(output_dir, name, idx):
return '%s/%s_%03d.tfrecord ' % (output_dir,name,idx)
def _process_image(directory, name, i):
    """
    Read one image and its PASCAL-VOC annotation (.xml) file.

    :param directory: dataset root directory
    :param name: image/annotation base name, without extension
    :param i: running index of the image; i + 1 is stored as the record id
    :return: tuple (image_data, shape, bboxes, labels, labels_text,
             difficult, truncated, xml_pic_name) where bboxes holds
             (ymin, xmin, ymax, xmax) tuples normalized to [0, 1]
    """
    filename = directory + DIRECTORY_IMAGES + name + '.jpg'  # image path
    print(filename)
    image_data = tf.gfile.FastGFile(filename, 'rb').read()  # raw JPEG bytes
    filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')  # annotation path
    tree = ET.parse(filename)  # parse the xml annotation
    root = tree.getroot()
    size = root.find('size')
    print(size)
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]  # image height, width, depth, e.g. [512, 767, 3]
    print(shape)
    xml_pic_name = i + 1
    bboxes = []
    labels = []
    labels_text = []  # e.g. [b'crack', b'crack', b'crack']
    difficult = []
    truncated = []
    all_object = root.findall('object')
    for obj in all_object:
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))  # numeric id of the label, e.g. 1 for 'crack'
        labels_text.append(label.encode('ascii'))  # label as bytes, e.g. b'crack'
        # BUG FIX: in ElementTree an Element with no children is falsy, so
        # `if obj.find('difficult'):` never took the true branch even when
        # the tag was present. Compare against None to test for presence.
        difficult_node = obj.find('difficult')
        if difficult_node is not None:
            difficult.append(int(difficult_node.text))
        else:
            difficult.append(0)
        truncated_node = obj.find('truncated')
        if truncated_node is not None:
            truncated.append(int(truncated_node.text))
        else:
            truncated.append(0)
        bbox = obj.find('bndbox')
        # Normalize box coordinates to [0, 1]: y by image height, x by width.
        ymin = float(bbox.find('ymin').text) / shape[0]
        xmin = float(bbox.find('xmin').text) / shape[1]
        ymax = float(bbox.find('ymax').text) / shape[0]
        xmax = float(bbox.find('xmax').text) / shape[1]
        bboxes.append((ymin, xmin, ymax, xmax))
    return image_data, shape, bboxes, labels, labels_text, difficult, truncated, xml_pic_name
def _convert_to_example(image_data, labels, labels_text, bboxes, shape,
                        difficult, truncated, xml_pic_name):
    """
    Build a tf.train.Example from one image and its annotations.

    :param image_data: raw encoded JPEG bytes
    :param labels: numeric label ids of the boxes in this image
    :param labels_text: labels as bytes, e.g. [b'crack', b'crack', b'crack']
    :param bboxes: list of (ymin, xmin, ymax, xmax) tuples, already scaled to [0, 1]
    :param shape: [height, width, depth] of the image
    :param difficult: per-object 'difficult' flags from the xml
    :param truncated: per-object 'truncated' flags from the xml
    :param xml_pic_name: integer record id (i + 1) stored under 'image/filename'
    :return: a populated tf.train.Example
    :raises ValueError: if any bbox does not have exactly 4 coordinates
    """
    # Split the box tuples into four parallel per-coordinate lists, the
    # layout expected by the float_feature entries below.
    xmin = []  # xmin of every box in this image
    ymin = []
    xmax = []
    ymax = []
    for b in bboxes:
        # Raise instead of assert: asserts are stripped under `python -O`.
        if len(b) != 4:
            raise ValueError('each bbox must have 4 coordinates, got %r' % (b,))
        ymin.append(b[0])
        xmin.append(b[1])
        ymax.append(b[2])
        xmax.append(b[3])
    print(ymin)
    print(xmin)
    print(ymax)
    print(xmax)
    image_format = b'JPEG'
    # NOTE(review): despite its name, 'image/filename' stores an int64 record
    # id (not the filename string); readers must decode it as int64, not bytes.
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(shape[0]),
        'image/width': int64_feature(shape[1]),
        'image/channels': int64_feature(shape[2]),
        'image/shape': int64_feature(shape),
        'image/object/bbox/xmin': float_feature(xmin),
        'image/object/bbox/xmax': float_feature(xmax),
        'image/object/bbox/ymin': float_feature(ymin),
        'image/object/bbox/ymax': float_feature(ymax),
        'image/filename': int64_feature(xml_pic_name),
        'image/object/bbox/label': int64_feature(labels),
        'image/object/bbox/label_text': bytes_feature(labels_text),
        'image/object/bbox/difficult': int64_feature(difficult),
        'image/object/bbox/truncated': int64_feature(truncated),
        'image/format': bytes_feature(image_format),
        'image/encoded': bytes_feature(image_data)}))
    return example
def _add_to_tfrecord(dataset_dir, name, tfrecord_writer, i):
    """
    Load one image plus its annotation, convert the pair to a
    tf.train.Example and append the serialized record to the open writer.

    :param dataset_dir: dataset root directory
    :param name: base name of the image to add (no extension)
    :param tfrecord_writer: open TFRecord writer to append to
    :param i: running image index, forwarded to _process_image
    """
    (image_data, shape, bboxes, labels, labels_text,
     difficult, truncated, xml_pic_name) = _process_image(dataset_dir, name, i)
    example = _convert_to_example(image_data, labels, labels_text, bboxes,
                                  shape, difficult, truncated, xml_pic_name)
    tfrecord_writer.write(example.SerializeToString())
def run(dataset_dir, output_dir, name='voc_train', shuffling=False):
    """
    Convert a PASCAL-VOC style dataset into TFRecord shards.

    :param dataset_dir: dataset root containing Annotations/ and JPEGImages/
    :param output_dir: directory the .tfrecord shards are written to
    :param name: base name of the output shards
    :param shuffling: if True, shuffle the file list (fixed seed) first
    """
    # BUG FIX: the original created *dataset_dir* when missing, which only
    # yields an empty input directory (and a crash at os.listdir below).
    # It is the *output* directory that must exist before writing shards.
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)
    path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)
    filenames = sorted(os.listdir(path))
    if shuffling:
        random.seed(RANDOM_SEED)  # fixed seed -> reproducible shuffle
        random.shuffle(filenames)
    i = 0
    fidx = 0
    while i < len(filenames):
        # Open a new tfrecord shard.
        tf_filename = _get_output_filename(output_dir, name, fidx)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            j = 0
            # Pack SAMPLES_PER_FILES images into the current shard.
            while i < len(filenames) and j < SAMPLES_PER_FILES:
                sys.stdout.write('\r>> 转化图片的进度 %d/%d' % (i+1, len(filenames)))
                sys.stdout.flush()
                filename = filenames[i]
                img_name = filename[:-4]  # strip the '.xml' extension
                _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer, i)
                i += 1
                j += 1
        fidx += 1
    print("所有数据全部转化成了TFRecord的格式")
xml文件
<?xml version="1.0" encoding="UTF-8"?>
-<annotation>
<folder>108MSDCF</folder>
<filename>DSC00001</filename>
<path>E:\tytd\Drone-pictures\Cracked-picture\108MSDCF\DSC00001.jpg</path>
-<source><database>Unknown</database>
</source>
-<size><width>767</width>
<height>512</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
-<object><name>crack</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
-<bndbox><xmin>485</xmin>
<ymin>29</ymin>
<xmax>523</xmax>
<ymax>45</ymax>
</bndbox>
</object>
-<object><name>crack</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
-<bndbox><xmin>529</xmin>
<ymin>169</ymin>
<xmax>577</xmax>
<ymax>186</ymax>
</bndbox>
</object>
-<object><name>crack</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
-<bndbox><xmin>494</xmin>
<ymin>207</ymin>
<xmax>573</xmax>
<ymax>231</ymax>
</bndbox>
</object>
</annotation>