之前已经讲过了generator,这次是要建立一个更详细的框架,数据到底怎么被处理的。可以作为样例学习。写的并不详细,因为你如果要做更深的工作,你需要很高自主能力,大多数人都具备,所以我就不废话了(主要是忙)。
源数据:图片/矩阵
目标数据:tensorflow 标准的Dataset
主要过程:
- 源数据->build_voc_data.py->tfrecord
- tfrecord->data_generator.py->Dataset
第一步将数据转换成tfrecord
直接读入的是二进制图像数据
之后利用类build_data.py直接将其转换成feature存入tfrecord,其实就是转换一下格式成为feature,然后用tf.train.Example画个格子装进去。
所以如果要多存入一些数据,比如深度数据,比如视频数据,那就需要改三个点:
1.build_data.py :def image_seg_to_tfexample加入你自己的定义。注意你的数据如果是整数就用_int64,浮点数就自己写个浮点数的,或者用string
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': _bytes_list_feature(image_data),
'image/filename': _bytes_list_feature(filename),
'image/format': _bytes_list_feature(
_IMAGE_FORMAT_MAP[FLAGS.image_format]),
'image/height': _int64_list_feature(height),
'image/width': _int64_list_feature(width),
'image/channels': _int64_list_feature(3),
'image/segmentation/class/encoded': (
_bytes_list_feature(seg_data)),
'image/segmentation/class/format': _bytes_list_feature(
FLAGS.label_format),
}))
2.build_voc_data.py:读入你自己的数据,下面就是读取图像和标注。你需要写你自己的数据读入。
image_filename = os.path.join(
FLAGS.image_folder, filenames[i] + '.' + FLAGS.image_format)
image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
height, width = image_reader.read_image_dims(image_data)
# Read the semantic segmentation annotation.
seg_filename = os.path.join(
FLAGS.semantic_segmentation_folder,
filenames[i] + '.' + FLAGS.label_format)
seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
seg_height, seg_width = label_reader.read_image_dims(seg_data)
3.build_voc_data.py:传递你的数据直接转换。
example = build_data.image_seg_to_tfexample(
image_data, filenames[i], height, width, seg_data)
第二步通过tfrecord建立Dataset
1.主函数是data_generator,调用了input_preprocess.py,preprocess调用了core.preprocess_utils.py
data_generator 读取tfrecord 需要改变features字典
features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/filename':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/format':
tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/height':
tf.FixedLenFeature((), tf.int64, default_value=0),
'image/width':
tf.FixedLe