Creating TFRecords in TensorFlow: the VOC Dataset as an Example

TFRecord is TensorFlow's own file format for packaging data, which makes it easy to read and move data quickly during training.
This post walks through the process using the VOC dataset.

1. Generating a TFRecord

The first step is to package the dataset, wrapping each sample in a tf.train.Example.

The implementation code is:

import tensorflow as tf

# Helpers that wrap plain Python values in tf.train.Feature protos.
def int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def bytes_list_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def float_list_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

# Read the raw JPEG bytes of one image (full_path points to the JPEG file).
with tf.io.gfile.GFile(full_path, 'rb') as fid:
  encoded_jpg = fid.read()
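These helpers do nothing more than wrap plain Python values in tf.train.Feature protos. As a quick illustration (the values are made up and the printed proto text format is shown condensed in the comments):

print(int64_feature(375))                    # int64_list { value: 375 }
print(bytes_feature('jpeg'.encode('utf8')))  # bytes_list { value: "jpeg" }
print(float_list_feature([0.25, 0.5]))       # float_list { value: 0.25 value: 0.5 }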

# Inside dict_to_tf_example: dataset_util refers to the module that provides the
# feature helpers defined above (object_detection.utils.dataset_util in the
# TensorFlow Object Detection API).
example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': dataset_util.int64_feature(height),
    'image/width': dataset_util.int64_feature(width),
    'image/filename': dataset_util.bytes_feature(
        data['filename'].encode('utf8')),
    'image/source_id': dataset_util.bytes_feature(
        data['filename'].encode('utf8')),
    'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
    'image/encoded': dataset_util.bytes_feature(encoded_jpg),
    'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
    'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
    'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
    'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
    'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
    'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
    'image/object/class/label': dataset_util.int64_list_feature(classes),
    'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
    'image/object/truncated': dataset_util.int64_list_feature(truncated),
    'image/object/view': dataset_util.bytes_list_feature(poses),
}))
return example
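The Example above assumes that height, width, the normalized box lists (xmin, ymin, xmax, ymax) and the class/difficult/truncated/pose lists have already been pulled out of the VOC annotation. Below is a minimal sketch of that step, assuming the standard VOC 2007 XML layout; the helper name parse_voc_annotation and the dict keys are illustrative, and name_to_label would be a mapping such as VOC_NAME_LABEL:

import xml.etree.ElementTree as ET

def parse_voc_annotation(xml_path, name_to_label):
    # Minimal sketch: read one VOC 2007 XML file and collect the per-object
    # lists the Example above expects (box coordinates normalized to [0, 1]).
    root = ET.parse(xml_path).getroot()
    width = float(root.find('size/width').text)
    height = float(root.find('size/height').text)
    ann = {'height': int(height), 'width': int(width),
           'filename': root.find('filename').text,
           'xmin': [], 'ymin': [], 'xmax': [], 'ymax': [],
           'classes_text': [], 'classes': [],
           'difficult': [], 'truncated': [], 'poses': []}
    for obj in root.findall('object'):
        name = obj.find('name').text
        ann['classes_text'].append(name.encode('utf8'))
        ann['classes'].append(name_to_label[name])
        ann['difficult'].append(int(obj.find('difficult').text))
        ann['truncated'].append(int(obj.find('truncated').text))
        ann['poses'].append(obj.find('pose').text.encode('utf8'))
        box = obj.find('bndbox')
        ann['xmin'].append(float(box.find('xmin').text) / width)
        ann['ymin'].append(float(box.find('ymin').text) / height)
        ann['xmax'].append(float(box.find('xmax').text) / width)
        ann['ymax'].append(float(box.find('ymax').text) / height)
    return ann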

# Write the examples out as a TFRecord file.
writer = tf.io.TFRecordWriter(output_path)
for idx, example in enumerate(examples_list):
    # `data` is the annotation dict parsed from this example's VOC XML file.
    tf_example = dict_to_tf_example(data, data_dir, VOC_NAME_LABEL,
                                    ignore_difficult_instances)
    writer.write(tf_example.SerializeToString())

writer.close()
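For context, examples_list above is usually just the list of image IDs read from a VOC split file. A sketch under the usual VOCdevkit layout (the paths are placeholders to adapt):

import os

data_dir = '/data/data/VOCdevkit/VOC2007'  # adjust to your local copy
split_file = os.path.join(data_dir, 'ImageSets', 'Main', 'trainval.txt')
with open(split_file) as f:
    examples_list = [line.strip() for line in f if line.strip()]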

2. Parsing the TFRecord

Parsing comes down to defining the feature spec IMAGE_FEATURE_MAP and the parsing function parse_example, which take the decoded fields, combine and process them, and produce the final output.

Note how the two kinds of feature specs map onto parsed values (the snippet after the feature map below shows this on a real record):

tf.io.FixedLenFeature([], tf.int64) ==> tf.Tensor(375, shape=(), dtype=int64)

tf.io.VarLenFeature(tf.int64) ==> SparseTensor(indices=tf.Tensor([[0]], shape=(1, 1), dtype=int64), values=tf.Tensor([12], shape=(1,), dtype=int64), dense_shape=tf.Tensor([1], shape=(1,), dtype=int64))

# Feature spec describing how to parse each stored field
IMAGE_FEATURE_MAP = {
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/source_id': tf.io.FixedLenFeature([], tf.string),
    'image/key/sha256': tf.io.FixedLenFeature([], tf.string),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32), # fields whose list length varies per example (one entry per object) are parsed with VarLenFeature
    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    'image/object/difficult': tf.io.VarLenFeature(tf.int64),
    'image/object/truncated': tf.io.VarLenFeature(tf.int64),
    'image/object/view': tf.io.VarLenFeature(tf.string),
}
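A quick way to see the two parsed types mentioned above is to parse a single record with this map and print one fixed-length and one variable-length field (assumes TF 2.x eager execution and the same TFRecord path used in the main block below):

raw_dataset = tf.data.TFRecordDataset(['/data/data/VOC2007/train.tfrecord'])
for raw_record in raw_dataset.take(1):
    parsed = tf.io.parse_single_example(raw_record, IMAGE_FEATURE_MAP)
    print(parsed['image/height'])            # dense scalar tensor, dtype=int64
    print(parsed['image/object/bbox/xmin'])  # SparseTensor; densify with tf.sparse.to_dense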

def parse_example(serialized_example, height=512, width=512):
  # Parse the serialized example into a dict of tensors.
  x = tf.io.parse_single_example(serialized_example, IMAGE_FEATURE_MAP)
  # Values can now be looked up by key.
  x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
  x_train = tf.image.resize(x_train, (height, width))
#  class_text = x['image/object/class/text'] # raw type is SparseTensor, see https://blog.csdn.net/JsonD/article/details/73105490
#  class_text = tf.sparse.to_dense(x['image/object/class/text'], default_value='')
  labels = tf.cast(tf.sparse.to_dense(x['image/object/class/label']), tf.float32)
  y_train = tf.stack([tf.sparse.to_dense(x['image/object/bbox/xmin']), # shape: [m]
                      tf.sparse.to_dense(x['image/object/bbox/ymin']), # shape: [m]
                      tf.sparse.to_dense(x['image/object/bbox/xmax']), # shape: [m]
                      tf.sparse.to_dense(x['image/object/bbox/ymax']), # shape: [m]
                      labels  # shape: [m]
                      ], axis=1) # shape: [m, 5]; m is the number of objects and can differ per image

  # Each image holds at most 100 objects.
  paddings = [[0, 100 - tf.shape(y_train)[0]], [0, 0]] # pad (100 - m) rows after the data, no column padding
  # The padded size of each dimension D of the output is:
  # paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]
  y_train = tf.pad(y_train, paddings)
  return x_train, y_train
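Because every label tensor is padded to a fixed [100, 5] shape, the parsed examples can be batched directly. A minimal input-pipeline sketch (the batch size and shuffle buffer are arbitrary choices):

train_ds = (tf.data.TFRecordDataset(['/data/data/VOC2007/train.tfrecord'])
            .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .shuffle(512)
            .batch(8)
            .prefetch(tf.data.experimental.AUTOTUNE))
# Each batch yields images of shape [8, 512, 512, 3] and labels of shape [8, 100, 5].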


def _parse_function(example_proto):
    # Parse the input `tf.Example` proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, IMAGE_FEATURE_MAP)


if __name__ == '__main__':
    dataset = tf.data.TFRecordDataset(filenames=['/data/data/VOC2007/train.tfrecord'])
    print(dataset)
    # raw_example = next(iter(dataset))
    # parsed = tf.train.Example.FromString(raw_example.numpy())
    # print(parsed)

    # for index ,record in enumerate(dataset):
    #     example = tf.io.parse_single_example(record,features=IMAGE_FEATURE_MAP)
    #     for key,value in example.items():
    #         print(key,'=>',value)

    # parsed_dataset = dataset.map(_parse_function)
    parsed_dataset = dataset.map(parse_example)  # map applies parse_example to each serialized example

    for parsed_record in parsed_dataset.take(10):
        print(repr(parsed_record))
        print('=========')

This post is largely based on yinghuang/yolov2-tensorflow2.

