import tensorflow as tf
import pascalvoc_to_tfrecords
# Command-line flags for the conversion script. The directory defaults are
# absolute Windows paths; override them with --dataset_dir / --output_dir.
tf.app.flags.DEFINE_string(
    'dataset_name',
    'pascalvoc',
    'The name of the dataset to convert.')
tf.app.flags.DEFINE_string(
    'dataset_dir',
    r"C:\Users\ADMIN\Desktop\mty\datasets",  # root folder the data is loaded from
    'Directory where the original dataset is stored.')
tf.app.flags.DEFINE_string(
    'output_name',
    'pascalvoc',  # name given to the output files
    'Basename used for TFRecords output files.')
tf.app.flags.DEFINE_string(
    'output_dir',
    r"C:\Users\ADMIN\Desktop\mty\datasets\voc",  # folder the data is saved to
    'Output directory where to store TFRecords files.')

# Handle through which the flag values defined above are read.
FLAGS = tf.app.flags.FLAGS
def main(_):
    """Convert the dataset selected by the command-line flags to TFRecords.

    The single positional argument is ignored; every input comes from FLAGS.

    Raises:
      ValueError: if the ``dataset_dir`` flag is empty, or if the
        ``dataset_name`` flag is anything other than 'pascalvoc'.
    """
    source_dir = FLAGS.dataset_dir
    if not source_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')

    print('Dataset directory:', source_dir)
    target_dir = FLAGS.output_dir
    print('Output directory:', target_dir)

    # 'pascalvoc' is the only dataset name this script knows how to convert.
    if FLAGS.dataset_name != 'pascalvoc':
        raise ValueError('Dataset [%s] was not recognized.' % FLAGS.dataset_name)

    pascalvoc_to_tfrecords.run(source_dir, target_dir, FLAGS.output_name)
if __name__ == '__main__':
    # Script entry point: name the callback explicitly instead of relying on
    # tf.app.run() looking up `main` in the __main__ module.
    tf.app.run(main=main)
# ---- module: pascalvoc_to_tfrecords ----
import os
import sys
import random
import numpy as np
import tensorflow as tf
import xml.etree.ElementTree as ET
from dataset_utils import int64_feature, float_feature, bytes_feature
from pascalvoc_common import VOC_LABELS
# Original dataset organisation.
# Name of the sub-folder holding the XML annotation files (the dataset folder
# must contain both the Annotations and JPEGImages sub-folders).
DIRECTORY_ANNOTATIONS = 'Annotations/'
# Name of the sub-folder holding the image files.
DIRECTORY_IMAGES = 'JPEGImages/'
# TFRecords conversion parameters.
RANDOM_SEED = 4242
# Number of original images whose data is stored in each TFRecord file.
SAMPLES_PER_FILES = 2
def _process_image(directory, name):
"""Process a image and annotation file.
Args:
filename: string, path to an image file e.g., '/path/to/example.JPG'.
coder: instance of ImageCoder to provide TensorFlow image coding utils.
Returns:
image_buffer: string, JPEG encoding of RGB image.
height: integer, image height in pixels.
width: integer, image width in pixels.
"""
# Read the image file.
# 1. 构造图像路径
filename = os.path.join(directory, DIRECTORY_IMAGES, name + '.jpg')
# 2. 加载数据
image_data = tf.gfile.FastGFile(filename, 'rb').read()
# Read the XML annotation file.
# 3. 构造xml文件路径
filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
# 4. xml文件解析
tree = ET.parse(filename)
# a. 得到xml文件对应的根节点
root = tree.getroot()
# Image shape.
# b. 获取图像大小信息
# 从根节点中获取第一个size子标签
size = root.find('size')
shape = [int(size.find('height').text),
int(size.find('width').text),
int(size.find('depth').text)]
# Find annotations.
# c. 获取边框信息
bboxes = []
labels = []
labels_text = []
difficult = []
truncated = []
# 从根节点中获取所有object标签的值(返回的是一个列表)
for obj in root.findall('object'):
# a. 获取当前物体的所属类别名称
label = obj.find('name').text
# b. 根据标签,将标签转换为索引id和byte数组,并添加集合中
labels.append(int(VOC_LABELS[label][0])) #注意这边VOC_LABELS必须和用标注工具标注时候的label相同
labels_text.append(label.encode('ascii'))
# c. 获取difficult标记信息,用于判断当前这个物体是否属于比较难预测的物体
if obj.find('difficult'):
difficult.append(int(obj.find('difficult').text))
else:
# 0表示预测不难
difficult.append(0)