Tensorflow学习笔记-通过slim读取TFRecord文件

  TFRecord文件格式的介绍:http://blog.csdn.net/lovelyaiq/article/details/78711944
  由于slim是tensorflow的高级API,使用起来比较方便,例如在卷积或全连接层的书写时,可以大大减少代码量。使用slim读取TFRecord文件与tensorflow直接读取还是有很大的却别。
  本文就以slim中的例子的flowers来说明。tfrecord中的格式定义为:

image_data = image_data = tf.gfile.FastGFile('img_path', 'rb').read()
def image_to_tfexample(image_data, image_format, height, width, class_id):
  return tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': bytes_feature(image_data),
      'image/format': bytes_feature(image_format),
      'image/class/label': int64_feature(class_id),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }))

原始图像经过处理后,生成5个文件。flowers_train_00000-of-00005.tfrecord到flowers_train_00004-of-00005.tfrecord。
训练时,就要通过slim从这5个文件中读取数据,然后组合成batch。代码如下:

  # 第一步
  # 将example反序列化成存储之前的格式。由tf完成
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }
 # 第一步
 # 将反序列化的数据组装成更高级的格式。由slim完成
 items_to_handlers = {
      'image': slim.tfexample_decoder.Image('image/encoded','image/format'),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }
# 解码器,进行解码
decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)
# dataset对象定义了数据集的文件位置,解码方式等元信息
dataset = slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],#训练数据的总数
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names #字典形式,格式为:id:class_call,
      )
# provider对象根据dataset信息读取数据
provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)

 # 获取数据,获取到的数据是单个数据,还需要对数据进行预处理,组合数据
 [image, label] = provider.get(['image', 'label'])
 # 图像预处理
 image = image_preprocessing_fn(image, train_image_size, train_image_size)

 images, labels = tf.train.batch(
              [image, label],
              batch_size=FLAGS.batch_size,
              num_threads=FLAGS.num_preprocessing_threads,
              capacity=5 * FLAGS.batch_size)
 labels = slim.one_hot_encoding(
              labels, dataset.num_classes - FLAGS.labels_offset)
 batch_queue = slim.prefetch_queue.prefetch_queue(
              [images, labels], capacity=2 * deploy_config.num_clones)
 # 组好后的数据
 images, labels = batch_queue.dequeue()

  至此,就可以使用images作为神经网络的输入,使用labels计算损失函数等操作。

  • 12
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值