TensorFlow: A Summary of Reading TFRecord Files

  1. TensorFlow study notes: reading TFRecord files with slim

An introduction to the TFRecord file format: http://blog.csdn.net/lovelyaiq/article/details/78711944
  Since slim is a high-level TensorFlow API, it is convenient to use; for example, it greatly reduces the amount of code needed to write convolutional or fully connected layers. Reading TFRecord files through slim is quite different from reading them with TensorFlow directly.
  This article uses the flowers example that ships with slim as an illustration. The format stored in the tfrecord is defined as:

image_data = tf.gfile.FastGFile('img_path', 'rb').read()
def image_to_tfexample(image_data, image_format, height, width, class_id):
  return tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': bytes_feature(image_data),
      'image/format': bytes_feature(image_format),
      'image/class/label': int64_feature(class_id),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }))
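
The bytes_feature and int64_feature helpers are not shown in the snippet. A minimal sketch of these wrappers, matching the ones slim's flowers example defines in datasets/dataset_utils.py:

def int64_feature(values):
  # Wrap a single int (or a list of ints) in a tf.train.Feature.
  if not isinstance(values, (tuple, list)):
    values = [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
  # Wrap a single byte string in a tf.train.Feature.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))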

After processing, the raw images are converted into 5 files, flowers_train_00000-of-00005.tfrecord through flowers_train_00004-of-00005.tfrecord.
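
For reference, a minimal sketch of how one such shard could be written, assuming the image_to_tfexample function above; image_filenames, heights, widths, and class_ids are hypothetical, pre-computed lists:

output_path = 'flowers_train_00000-of-00005.tfrecord'
with tf.python_io.TFRecordWriter(output_path) as writer:
    for img_path, height, width, class_id in zip(image_filenames, heights, widths, class_ids):
        # Read the raw encoded bytes; they are stored as-is, not decoded.
        image_data = tf.gfile.FastGFile(img_path, 'rb').read()
        example = image_to_tfexample(image_data, b'jpg', height, width, class_id)
        writer.write(example.SerializeToString())
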
During training, slim reads data from these 5 files and assembles it into batches. The code is as follows:

  # Step 1
  # Deserialize the example back into the format it was stored in. Done by tf.
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }
 # Step 2
 # Assemble the deserialized data into higher-level items. Done by slim.
 items_to_handlers = {
      'image': slim.tfexample_decoder.Image('image/encoded','image/format'),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }
# The decoder performs the actual decoding
decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)
# The Dataset object holds the dataset's metadata: file locations, decoding method, etc.
dataset = slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],  # total number of samples in the split
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names  # a dict of the form {id: class_name}
      )
# The provider reads data according to the dataset's metadata
provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)

 # Fetch the data. provider.get returns a single example, which still needs preprocessing and batching.
 [image, label] = provider.get(['image', 'label'])
 # Image preprocessing
 image = image_preprocessing_fn(image, train_image_size, train_image_size)

 images, labels = tf.train.batch(
              [image, label],
              batch_size=FLAGS.batch_size,
              num_threads=FLAGS.num_preprocessing_threads,
              capacity=5 * FLAGS.batch_size)
 labels = slim.one_hot_encoding(
              labels, dataset.num_classes - FLAGS.labels_offset)
 batch_queue = slim.prefetch_queue.prefetch_queue(
              [images, labels], capacity=2 * deploy_config.num_clones)
 # The assembled batch
 images, labels = batch_queue.dequeue()

  At this point, images can be fed into the neural network as its input, labels can be used to compute the loss function, and so on.
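
To actually pull batches out of this queue-based pipeline, the queue runners created by tf.train.batch and prefetch_queue must be started inside a session. A minimal consumption sketch (TF 1.x; the names images and labels follow the snippet above):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # Start the queue-runner threads that feed the batch/prefetch queues.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        batch_images, batch_labels = sess.run([images, labels])
        print(batch_images.shape, batch_labels.shape)
    finally:
        coord.request_stop()
        coord.join(threads)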

Example: from an SSD implementation

def slim_get_batch(num_classes, batch_size, split_name, file_pattern, num_readers, num_preprocessing_threads, image_preprocessing_fn, anchor_encoder, num_epochs=None, is_training=True):
    """Gets a dataset tuple with instructions for reading Pascal VOC dataset.

    Args:
      num_classes: total class numbers in dataset.
      batch_size: the size of each batch.
      split_name: 'train' or 'val'.
      file_pattern: The file pattern to use when matching the dataset sources (full path).
      num_readers: the max number of readers used for reading tfrecords.
      num_preprocessing_threads: the max number of threads used to run the preprocessing function.
      image_preprocessing_fn: the function used for dataset augmentation.
      anchor_encoder: the function used to encode all anchors.
      num_epochs: total number of epochs to iterate over this dataset.
      is_training: whether we are in the training phase.

    Returns:
      A batch of [image, filename, shape, loc_targets, cls_targets, match_scores].
    """
    if split_name not in data_splits_num:
        raise ValueError('split name %s was not recognized.' % split_name)

    # Features in Pascal VOC TFRecords.
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/height': tf.FixedLenFeature([1], tf.int64),
        'image/width': tf.FixedLenFeature([1], tf.int64),
        'image/channels': tf.FixedLenFeature([1], tf.int64),
        'image/shape': tf.FixedLenFeature([3], tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
    }
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'filename': slim.tfexample_decoder.Tensor('image/filename'),
        'shape': slim.tfexample_decoder.Tensor('image/shape'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
                ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
        'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
        'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    labels_to_names = {}
    for name, pair in VOC_LABELS.items():
        labels_to_names[pair[0]] = name

    dataset = slim.dataset.Dataset(
                data_sources=file_pattern,  #/train-*
                reader=tf.TFRecordReader,
                decoder=decoder,
                num_samples=data_splits_num[split_name],  # split_name is 'train' or 'val'
                items_to_descriptions=None,
                num_classes=num_classes,
                labels_to_names=labels_to_names)  #{0: 'none', 1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat', 5: 'bottle', 6: 'bus', 7: 'car', 8: 'cat', 9: 'chair', 10: 'cow', 11: 'diningtable', 12: 'dog', 13: 'horse', 14: 'motorbike', 15: 'person', 16: 'pottedplant', 17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor'}

    with tf.name_scope('dataset_data_provider'):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=num_readers,
            common_queue_capacity=32 * batch_size,
            common_queue_min=8 * batch_size,
            shuffle=is_training,
            num_epochs=num_epochs)

    [org_image, filename, shape, glabels_raw, gbboxes_raw, isdifficult] = provider.get(
        ['image', 'filename', 'shape', 'object/label', 'object/bbox', 'object/difficult'])

    if is_training:
        # if all the objects are marked difficult, keep only the first one
        isdifficult_mask = tf.cond(tf.count_nonzero(isdifficult, dtype=tf.int32) < tf.shape(isdifficult)[0],
                                   lambda: isdifficult < tf.ones_like(isdifficult),
                                   lambda: tf.one_hot(0, tf.shape(isdifficult)[0], on_value=True, off_value=False, dtype=tf.bool))

        glabels_raw = tf.boolean_mask(glabels_raw, isdifficult_mask)
        gbboxes_raw = tf.boolean_mask(gbboxes_raw, isdifficult_mask)

    # Pre-processing image, labels and bboxes.

    if is_training:
        image, glabels, gbboxes = image_preprocessing_fn(org_image, glabels_raw, gbboxes_raw)
    else:
        image = image_preprocessing_fn(org_image, glabels_raw, gbboxes_raw)
        glabels, gbboxes = glabels_raw, gbboxes_raw

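    # Encode the ground-truth labels and boxes against the model's anchors.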
    gt_targets, gt_labels, gt_scores = anchor_encoder(glabels, gbboxes)

    return tf.train.batch([image, filename, shape, gt_targets, gt_labels, gt_scores],
                    dynamic_pad=False,
                    batch_size=batch_size,
                    allow_smaller_final_batch=(not is_training),
                    num_threads=num_preprocessing_threads,
                    capacity=64 * batch_size)
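
A hypothetical call of the function above. preprocessing_fn and anchor_encoder_fn stand in for the project's own preprocessing and anchor-encoding functions, and the file pattern is illustrative:

image, filename, shape, loc_targets, cls_targets, match_scores = slim_get_batch(
        num_classes=21,                        # 20 Pascal VOC classes + background
        batch_size=32,
        split_name='train',
        file_pattern='./tfrecords/train-*',    # assumed shard location
        num_readers=4,
        num_preprocessing_threads=4,
        image_preprocessing_fn=preprocessing_fn,   # hypothetical
        anchor_encoder=anchor_encoder_fn,          # hypothetical
        is_training=True)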
