Using tf.data.TFRecordDataset to Read TFRecord Files in Object Detection

This article draws on "Getting started with TensorFlow: TFRecord and tf.data.TFRecordDataset usage".

1. Writing TFRecords

import tensorflow as tf

# Thin wrappers for building tf.train.Feature protos of each supported type
# (int64, bytes, float), for both scalar and list values.
def int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def bytes_list_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))
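For orientation, a minimal sketch (with made-up values) of the Example proto these helpers build:

# Hypothetical values, just to show what the helpers produce.
example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': int64_feature(480),
    'image/object/class/text': bytes_list_feature([b'car', b'person']),
    'image/object/bbox/xmin': float_list_feature([0.1, 0.4]),
}))
print(example.features.feature['image/height'].int64_list.value)  # [480]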

def read_examples_list(path):
  """Reads a dataset split file such as trainval.txt.

  Each line is assumed to look like: "xyz 3", where the first token is the
  file stem; this lets us find files xyz.jpg and xyz.xml (the 3 is ignored).
  """
  with tf.gfile.GFile(path) as fid:
    lines = fid.readlines()
  return [line.strip().split(' ')[0] for line in lines]

def recursive_parse_xml_to_dict(xml):
  """Recursively parses a Pascal VOC-style XML tree into a Python dict.

  We assume that `object` tags are the only ones that can appear
  multiple times at the same level of a tree.
  Args:
    xml: xml tree obtained by parsing XML file contents using lxml.etree
  Returns:
    Python dictionary holding XML contents.
  """
  if not xml:  # a leaf element (no children) maps to {tag: text}
    return {xml.tag: xml.text}
  result = {}
  for child in xml:
    child_result = recursive_parse_xml_to_dict(child)
    if child.tag != 'object':
      result[child.tag] = child_result[child.tag]
    else:
      if child.tag not in result:
        result[child.tag] = []
      result[child.tag].append(child_result[child.tag])
  return {xml.tag: result}
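A minimal sketch of recursive_parse_xml_to_dict on a made-up Pascal VOC-style snippet:

from lxml import etree

xml_str = b'<annotation><filename>xyz.jpg</filename>' \
          b'<object><name>car</name></object>' \
          b'<object><name>person</name></object></annotation>'
data = recursive_parse_xml_to_dict(etree.fromstring(xml_str))['annotation']
print(data['filename'])                         # xyz.jpg
print([obj['name'] for obj in data['object']])  # ['car', 'person']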

Example:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets("MNIST_data/", one_hot=True)

# Pack a single (feature, label) pair into a tf.train.Example.
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(
        int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(
        float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

# Serialize every example and write them all to one tfrecord file.
def make_tfrecord(data, outf_nm='mnist-train'):
    feats, labels = data
    outf_nm += '.tfrecord'
    tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
    ndatas = len(labels)
    for inx in range(ndatas):
        exmp = get_tfrecords_example(feats[inx], labels[inx])
        exmp_serial = exmp.SerializeToString()
        tfrecord_wrt.write(exmp_serial)
    tfrecord_wrt.close()

import random
nDatas = len(mnist.train.labels)
inx_lst = list(range(nDatas))  # list() needed so shuffle works under Python 3
random.shuffle(inx_lst)
ntrains = int(0.85 * nDatas)

# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]],
        [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')

# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]],
        [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')

# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')
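After writing, a quick sanity check is to count the records in each output file. A minimal sketch using the TF 1.x record iterator:

# Count the records in each file to verify the 85/15 split.
for fname in ['mnist-train.tfrecord', 'mnist-val.tfrecord', 'mnist-test.tfrecord']:
    n = sum(1 for _ in tf.python_io.tf_record_iterator(fname))
    print(fname, n)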

2. Reading and Parsing TFRecord Files

  • Import libraries
import os, sys, PIL, matplotlib, itertools
matplotlib.use("Agg")  # select a non-interactive backend before pyplot is imported
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
LIB_DIR = r"E:\MyCollectionFinished\Git-models-1.12.0\research\object_detection"
sys.path.append(LIB_DIR)
from inference import detection_inference
from utils import dataset_util
from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.protos import input_reader_pb2
from object_detection.utils import label_map_util

DAT_DIR = r"E:\MyCollectionFinished\Git-models-1.12.0\research\object_detection\my_TLR_DCR_v1_0\datasets\TLR_DCR_v1_0\TFRecordDet_v2_0"
input_tfrecord_paths = [os.path.join(DAT_DIR, "tlr_dcr_val_v2_0.tfrecord")]
  • Define the parse function for a single example
def parse_one_example_a(serialized_example):
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/source_id': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/height': tf.FixedLenFeature((), tf.int64, default_value=1),
        'image/width': tf.FixedLenFeature((), tf.int64, default_value=1),
        # Image-level labels.
        'image/class/text': tf.VarLenFeature(tf.string),
        'image/class/label': tf.VarLenFeature(tf.int64),
        # Object boxes and classes. 
        'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
        'image/object/class/label': tf.VarLenFeature(tf.int64),
        'image/object/class/text': tf.VarLenFeature(tf.string),
        'image/object/area': tf.VarLenFeature(tf.float32),
        'image/object/is_crowd': tf.VarLenFeature(tf.int64),
        'image/object/difficult': tf.VarLenFeature(tf.int64),
        'image/object/group_of': tf.VarLenFeature(tf.int64),
        'image/object/weight': tf.VarLenFeature(tf.float32),
        'image/object/mask': tf.VarLenFeature(tf.string),     # PNG Mask.
        'image/object/weight1': tf.VarLenFeature(tf.float32),
    }
    features = tf.parse_single_example(serialized_example, keys_to_features)
    tensor_dict = {}
    tensor_dict[fields.InputDataFields.image] = tf.image.decode_image(features['image/encoded'])
    tensor_dict[fields.InputDataFields.source_id] = tf.cast(features['image/source_id'], tf.string)
    # VarLenFeature fields come back as SparseTensors.
    tensor_dict[fields.InputDataFields.groundtruth_boxes + "/xmin"] = features['image/object/bbox/xmin']
    # 'image/object/weight1' does not exist in the file; see the note below.
    tensor_dict['image/object/weight1'] = tf.cast(features['image/object/weight1'], tf.float32)
    return tensor_dict
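Every VarLenFeature above parses to a SparseTensor. If dense coordinates are wanted downstream, they can be converted inside parse_one_example_a; a sketch of one such line (an assumption, not part of the original parser):

# Hypothetical: densify the xmin coordinates before returning them.
tensor_dict[fields.InputDataFields.groundtruth_boxes + "/xmin"] = \
    tf.sparse_tensor_to_dense(features['image/object/bbox/xmin'], default_value=0)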
  • Build the dataset and create the iterator
dataset = tf.data.TFRecordDataset(input_tfrecord_paths)
dataset = dataset.map(parse_one_example_a).repeat(1).shuffle(10000)
train_iterator = dataset.make_initializable_iterator()
data_tensor = train_iterator.get_next()
  • Test the output
with tf.Session() as sess:
    sess.run(train_iterator.initializer)
    inputs = sess.run(data_tensor)
    print(inputs.keys())
    print(inputs['source_id'].decode('utf-8'))
    print(inputs['image'].shape)
    print(inputs['image/object/weight1'])
    print(type(inputs['image/object/weight1']))
  • Output
    (Note: 'image/object/weight1' was deliberately added to the parser even though no such key exists in the TFRecord. Parsing does not fail; it just returns an odd-looking empty SparseTensorValue. So this approach still cannot catch, say, a misspelled key in the feature spec.)
dict_keys(['image', 'source_id', 'groundtruth_boxes/xmin', 'image/object/weight1'])
VID_20200419_120214_00055.jpg
(1280, 720, 3)
SparseTensorValue(indices=array([], shape=(0, 1), dtype=int64), values=array([], dtype=float32), dense_shape=array([0], dtype=int64))
<class 'tensorflow.python.framework.sparse_tensor.SparseTensorValue'>
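Since a wrong key only shows up as an empty SparseTensorValue, one way to double-check is to parse a raw record directly and list the keys it actually contains. A sketch using the TF 1.x record iterator:

# List the feature keys stored in the first record of the file.
record = next(tf.python_io.tf_record_iterator(input_tfrecord_paths[0]))
example = tf.train.Example.FromString(record)
print(sorted(example.features.feature.keys()))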
The MNIST TFRecords written in section 1 can be read back the same way. Example:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={
        'feature': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([10], tf.float32),
        'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape


def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50  # if nDatas is not divisible by batch_size (e.g. 56),
                 # the last batch will contain fewer than batch_size examples

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)
  # make sure repeat comes before batch; this differs from
  # dataset.shuffle(1000).batch(batch_size).repeat(epochs), where every epoch
  # ends with a short batch whenever nDatas is not divisible by batch_size
nBatchs = nDatasTrain*epochs//batch_size
 
# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs//100*2)
 
# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)
 
# make dataset iterator
iter_train = dataset_train.make_one_shot_iterator()
iter_val  = dataset_val.make_one_shot_iterator()
iter_test  = dataset_test.make_one_shot_iterator()
 
# make feedable iterator
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()
train_op, loss, eval_op = model(x, y_)  # model() is assumed to be defined elsewhere
init = tf.global_variables_initializer()  # tf.initialize_all_variables() is deprecated
 
# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    tf.summary.scalar(datapart + '-loss', loss)
    tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge_all()  # caution: merge_all() merges every summary defined so far
summary_op_train = summary_op()
summary_op_test = summary_op('val')
 
with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    _, cur_loss, cur_train_eval, summary = sess.run(
        [train_op, loss, eval_op, summary_op_train],
        feed_dict={handle: handle_train, keep_prob: 0.5})
    cur_val_loss, cur_val_eval, summary = sess.run(
        [loss, eval_op, summary_op_test],
        feed_dict={handle: handle_val, keep_prob: 1.0})
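As written, the two sess.run calls execute only once; normally they would sit inside a training loop within the with-block. A minimal sketch of such a loop (the tf.summary.FileWriter for persisting summaries is an addition, not part of the original snippet):

    writer = tf.summary.FileWriter(logdir, sess.graph)  # hypothetical: persist summaries
    for step in range(nBatchs):
        _, cur_loss, cur_train_eval, summary = sess.run(
            [train_op, loss, eval_op, summary_op_train],
            feed_dict={handle: handle_train, keep_prob: 0.5})
        writer.add_summary(summary, step)
        if step % 100 == 0:  # periodic validation; dataset_val only repeats nBatchs//100*2 times
            cur_val_loss, cur_val_eval, val_summary = sess.run(
                [loss, eval_op, summary_op_test],
                feed_dict={handle: handle_val, keep_prob: 1.0})
            writer.add_summary(val_summary, step)
    writer.close()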