TensorFlow 2:VOC2012数据集的TFRecord文件读写
TFRecord是TensorFlow内置的文件格式,是一种二进制文件。有时我们会利用同一数据集训练不同的网络框架,而对于较大的数据集,不断地复制、移动和读取较为不便;TFRecord文件可将二进制数据和标签存储到同一个文件中,可以更好地利用内存,也方便复制和移动。
1 读取VOC-xml文件
XML文件的数据存储格式类似于数据结构中的树结构,annotation为根节点。在VOC2012中主要需要获取filename、width、height(后两者位于size节点下),以及object(图像中的目标)下的name(目标名称)和bndbox(目标框的位置)。示例如下:
<annotation>
<folder>VOC2012</folder>
<filename>2007_000123.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>500</width>
<height>375</height>
<depth>3</depth>
</size>
<segmented>1</segmented>
<object>
<name>train</name>
<pose>Unspecified</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>1</xmin>
<ymin>26</ymin>
<xmax>358</xmax>
<ymax>340</ymax>
</bndbox>
</object>
</annotation>
import xml.dom.minidom as xdom
# Build the class-name -> integer-label table from voc_classes.txt (one class
# name per line). Labels start at 1 and follow line order; 0 is left unused
# (presumably reserved for background -- confirm with the training code).
with open('model_data/voc_classes.txt', 'r', encoding='utf-8') as class_file:
    classes_names = class_file.readlines()
voc_classes = {
    name.strip(): i
    for i, name in enumerate(classes_names, start=1)
    if name.strip()  # skip blank lines such as a trailing empty line
}
print(voc_classes)
def _first_child(node, tag):
    """Return the first element named *tag* directly under *node*.

    Falls back to a recursive ``getElementsByTagName`` search when no direct
    child matches, so unexpected nesting still resolves as before.
    """
    for child in node.childNodes:
        if child.nodeType == child.ELEMENT_NODE and child.tagName == tag:
            return child
    return node.getElementsByTagName(tag)[0]


def _child_text(node, tag):
    """Return the whitespace-stripped text of the first *tag* element of *node*."""
    return _first_child(node, tag).childNodes[0].data.strip()


def Prase_Singel_xml(xml_path, class_map=None):
    """Parse one VOC-style XML annotation file.

    Args:
        xml_path: file name or file object accepted by
            ``xml.dom.minidom.parse``.
        class_map: mapping from object class name to integer label. Defaults
            to the module-level ``voc_classes`` table.

    Returns:
        A tuple ``(image_name, image_width, image_height, boxes)``. ``boxes``
        holds one ``[xmin, ymin, xmax, ymax, label]`` list of ints per
        ``<object>`` element, in document order (empty if there are none).

    Raises:
        KeyError: if an object's class name is missing from ``class_map``.
    """
    if class_map is None:
        class_map = voc_classes
    root = xdom.parse(xml_path).documentElement

    # Image file name and size (width/height live under <size>).
    image_name = _child_text(root, "filename")
    size = _first_child(root, "size")
    image_height = int(_child_text(size, "height"))
    image_width = int(_child_text(size, "width"))

    boxes = []
    for obj in root.getElementsByTagName("object"):
        label = class_map[_child_text(obj, "name")]
        # Take the object's own <name>/<bndbox> (direct children), not ones
        # belonging to nested elements such as VOC person-layout <part>s,
        # which a descendant search could return first.
        bndbox = _first_child(obj, "bndbox")
        # Parse through float() so decimal coordinate text (e.g. "26.5") does
        # not crash int(); the value is truncated to an int pixel coordinate.
        coords = [int(float(_child_text(bndbox, tag)))
                  for tag in ("xmin", "ymin", "xmax", "ymax")]
        boxes.append(coords + [label])
    return image_name, image_width, image_height, boxes
if __name__ == '__main__':
    # Smoke test: parse a single annotation file and print what was extracted.
    demo_xml = 'VOCdevkit/VOC2012/Annotations/2007_000027.xml'
    parsed = Prase_Singel_xml(demo_xml)
    print(parsed)
2 写入tfrecords文件
import tensorflow as tf
import glob
import os
from read_xml import Prase_Singel_xml
def write_to_tfrecord(all_xml_path, tfrecord_path, voc_img_path):
    """Serialize VOC images and their box annotations into one TFRecord file.

    Each annotation file is parsed with ``Prase_Singel_xml``. The referenced
    image is read as raw (still encoded) bytes, and one ``tf.train.Example``
    is written with these features:

    * ``image`` -- the raw image file bytes
    * ``width``, ``height`` -- image size, stored as floats
    * ``xmin``, ``ymin``, ``xmax``, ``ymax`` -- one float per object
    * ``label`` -- one int64 class id per object

    Args:
        all_xml_path: iterable of VOC annotation (.xml) file paths.
        tfrecord_path: path of the TFRecord file to create.
        voc_img_path: directory containing the images named in the XML files.
    """
    # Close the writer explicitly (via the context manager) so every record
    # is flushed to disk, even if a later sample raises part-way through.
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for i, single_xml_path in enumerate(all_xml_path):
            image_name, image_width, image_height, boxes = Prase_Singel_xml(single_xml_path)
            with open(os.path.join(voc_img_path, image_name), 'rb') as image_file:
                image_data = image_file.read()
            # Each box is [xmin, ymin, xmax, ymax, label]; regroup into one list
            # per field (all empty for an image without objects).
            xmin = [box[0] for box in boxes]
            ymin = [box[1] for box in boxes]
            xmax = [box[2] for box in boxes]
            ymax = [box[3] for box in boxes]
            obj_label = [box[4] for box in boxes]
            feature = {
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
                'width': tf.train.Feature(float_list=tf.train.FloatList(value=[image_width])),
                'height': tf.train.Feature(float_list=tf.train.FloatList(value=[image_height])),
                'xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmin)),
                'ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymin)),
                'xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmax)),
                'ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymax)),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=obj_label)),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
            print('第{}张图片写入完毕'.format(i))
if __name__ == '__main__':
    # Convert every VOC2012 annotation (plus its JPEG) into a single TFRecord.
    write_to_tfrecord(
        all_xml_path=glob.glob('VOCdevkit/VOC2012/Annotations/*.xml'),
        tfrecord_path='voc_2012.tfrecords',
        voc_img_path='VOCdevkit/VOC2012/JPEGImages',
    )
3 读取tfrecords文件
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Schema of the records produced by write_to_tfrecord: scalar image bytes and
# size, plus variable-length per-object box coordinates and labels.
_FEATURE_DESCRIPTION = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'width': tf.io.FixedLenFeature([], tf.float32),
    'height': tf.io.FixedLenFeature([], tf.float32),
    'xmin': tf.io.VarLenFeature(tf.float32),
    'ymin': tf.io.VarLenFeature(tf.float32),
    'xmax': tf.io.VarLenFeature(tf.float32),
    'ymax': tf.io.VarLenFeature(tf.float32),
    'label': tf.io.VarLenFeature(tf.int64),
}


def parse_example(example_string, feature_description=None):
    """Decode one serialized ``tf.train.Example`` back into image and boxes.

    Args:
        example_string: scalar string tensor holding a serialized Example.
        feature_description: feature spec passed to
            ``tf.io.parse_single_example``. Defaults to the module-level
            schema matching the writer, so the function also works when
            imported rather than only under ``__main__``.

    Returns:
        ``(image, width, height, boxes, labels)``. ``image`` is the decoded
        JPEG (uint8, 3 channels), ``boxes`` is an ``[N, 4]`` float tensor of
        ``[xmin, ymin, xmax, ymax]`` rows, and ``labels`` is an ``[N]`` int64
        tensor.
    """
    if feature_description is None:
        feature_description = _FEATURE_DESCRIPTION
    feature_dict = tf.io.parse_single_example(example_string, feature_description)
    # channels=3 forces RGB output, giving every image the same rank/depth.
    image_data = tf.io.decode_jpeg(feature_dict['image'], channels=3)
    # VarLenFeatures come back as SparseTensors; densify each coordinate
    # list and stack them column-wise into [N, 4].
    boxes = tf.stack(
        [tf.sparse.to_dense(feature_dict[key]) for key in ('xmin', 'ymin', 'xmax', 'ymax')],
        axis=1,
    )
    boxes_category = tf.sparse.to_dense(feature_dict['label'])
    return image_data, feature_dict['width'], feature_dict['height'], boxes, boxes_category
if __name__ == '__main__':
    raw_datasets = tf.data.TFRecordDataset('voc_2012.tfrecords')
    # Schema of the records written by the writer script. Kept as a global
    # name in case parse_example looks it up by that name.
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'width': tf.io.FixedLenFeature([], tf.float32),
        'height': tf.io.FixedLenFeature([], tf.float32),
        'xmin': tf.io.VarLenFeature(tf.float32),
        'ymin': tf.io.VarLenFeature(tf.float32),
        'xmax': tf.io.VarLenFeature(tf.float32),
        'ymax': tf.io.VarLenFeature(tf.float32),
        'label': tf.io.VarLenFeature(tf.int64),
    }
    raw_datasets = raw_datasets.map(parse_example)
    print(raw_datasets)

    # Show the first 12 samples on a 3 x 4 grid, drawing each box in red.
    plt.figure(figsize=(15, 10))
    for i, (image, width, height, boxes, boxes_category) in enumerate(raw_datasets.take(12)):
        plt.subplot(3, 4, i + 1)
        plt.imshow(image.numpy())
        ax = plt.gca()
        # Each row of boxes is [xmin, ymin, xmax, ymax], in the coordinates
        # stored by the writer.
        for xmin, ymin, xmax, ymax in boxes.numpy():
            ax.add_patch(Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color='r', fill=False))
    plt.show()