tensorfow学习(一) ——tf.data.TFRecordDataset的使用

之前一直用tfRecord的队列读入格式, 偶然逛官网发现有更方便的tf.data。tensorflow官网其实已经给了很完整的说明,包括各种的数据格式,其他数据可以看tensorflow中文文档

一. 创建dataSet

  1. 定义-tf.data.TFRecordDataset

    # fileNames指的是你要用tfrecord文件的路径
    dataset = tf.data.TFRecordDataset(filenames)
    
  2. dataset.repeat(num)
    num为空表示无限重复下去
    不设置则表示只重复一次

  3. dataset.shuffle(shuffle_num)
    shuffle_num指的打乱数,我的理解是一个打乱顺序队列,取出后可以再填入
    4.** dataset.batch(batch-size)**
    都懂,就是设置batch-size

  4. dataset.map()
    我的理解是一个预处理,里面是一个函数对象,可以用lambda表达式代替。如果用tfRecord数据,里面的函数就是解析数据的方式

二、消耗数据-iterator

iterator有很多种,分为单次可初始化可馈送可重新初始化

  • 单次
    单次迭代器是最简单的迭代器形式,仅支持对数据集进行一次迭代,不需要显式初始化。单次迭代器可以处理基于队列的现有输入管道支持的几乎所有情况,但它们不支持参数化(也就是不能给他们赋不同的dataSet-我的理解)

      dataset = tf.data.Dataset.range(100)
      iterator = dataset.make_one_shot_iterator()
      next_element = iterator.get_next()
      for i in range(100):
      	 value = sess.run(next_element)
    		assert i == value
    
  • 可初始化
    需要先运行显式 iterator.initializer 操作,然后才能使用可初始化迭代器。虽然有些不便,但它允许您使用一个或多个 tf.placeholder() 张量(可在初始化迭代器时馈送)参数化数据集的定义。

    max_value = tf.placeholder(tf.int64, shape=[])
    dataset = tf.data.Dataset.range(max_value)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    
        # Initialize an iterator over a dataset with 10 elements.
        sess.run(iterator.initializer, feed_dict={max_value: 10})
        for i in range(10):
          value = sess.run(next_element)
          assert i == value
        
        # Initialize the same iterator over a dataset with 100 elements.
        sess.run(iterator.initializer, feed_dict={max_value: 100})
        for i in range(100):
          value = sess.run(next_element)
          assert i == value
    
  • 可重新初始化

    iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                              training_dataset.output_shapes)
    
  • 可馈送

    iterator  =	tf.data.Iterator.from_string_handle(
      handle, training_dataset.output_types, training_dataset.output_shapes)
    

三、使用

使用时,当一个数据集到达末尾时候,会引发tf.errors.OutOfRangeError, 捕获这个error即代表数据集结束,取数据用next_element = iterator.get_next()

四 、例程

我个人用官方的感觉有点复杂,我个人理解并使用的是,用可重新初始化迭代器初始化训练数据(为了可以有训练完一个数据集的信号),用单次迭代器配合无限重复次数使用验证集(我的程序会运行一次训练集的同时运行测试集判断有没有过拟合)。话不多说,看代码吧。

  • 创建数据集函数
def create_dataset(filenames, batch_size=8, is_shuffle=False, n_repeats=0):
    """

    :param filenames: record file names
    :param batch_size:
    :param is_shuffle: 是否打乱数据
    :param n_repeats:
    :return:
    """
    dataset = tf.data.TFRecordDataset(filenames)
    if n_repeats > 0:
        dataset = dataset.repeat(n_repeats)         # for train
    if n_repeats == -1:
        dataset = dataset.repeat()  # for val to
    dataset = dataset.map(lambda x: parse_single_exmp(x, labels_nums=NUM_CLASS))
    if is_shuffle:
        dataset = dataset.shuffle(10000)            # shuffle
    dataset = dataset.batch(batch_size)
    return dataset
  • 预处理tfRecord解析函数
def parse_single_exmp(serialized_example,labels_nums=2):
    """
    解析tf.record
    :param serialized_example:
    :param opposite: 是否将图片取反
    :return:
    """
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#获得图像原始的数据
    tf_height = features['height']
    tf_width = features['width']
    tf_depth = features['depth']
    tf_label = tf.cast(features['label'], tf.int32)
    # PS:恢复原始图像数据,reshape的大小必须与保存之前的图像shape一致,否则出错
    # tf_image=tf.reshape(tf_image, [-1])    # 转换为行向量
    tf_image =tf.reshape(tf_image, [224, 224, 3]) # 设置图像的维度
    tf_image = tf.cast(tf_image, tf.float32)
    tf_image = prepeocess(
        tf_image, choice=True)
    tf_label = tf.one_hot(tf_label, labels_nums, 1, 0)
    print(tf_image)
    return tf_image, tf_label

  • 运行
import tensorflow as tf
import os
import tensorflow.contrib.slim as slim

# 读取tfrecord文件并列成列表,train_dir是存放的路径
train_file_names = [os.path.join(train_dir, i) for i in os.listdir(train_dir)]
val_file_names = [os.path.join(val_dir,  i) for i in os.listdir(val_dir)]

# 定义数据集
training_dataset = create_dataset(train_file_names, batch_size=BATCH_SIZE,
                                      is_shuffle=True, n_repeats=0)  # train_filename
# train_dataset 用epochs控制循环
validation_dataset = create_dataset(val_file_names, batch_size=BATCH_SIZE,
                                  is_shuffle=False, n_repeats=-1)  # val

# 定义迭代器
train_iterator = training_dataset.make_initializable_iterator()
# make_initializable_iterator 每个epoch都需要初始化
val_iterator = validation_dataset.make_one_shot_iterator()
# make_one_shot_iterator不需要初始化,根据需要不停循环
train_images, train_labels = train_iterator.get_next()
val_images, val_labels = val_iterator.get_next()

 for epoch in range(NUM_EPOCHS):
	   print('Starting epoch %d / %d' % (epoch + 1, NUM_EPOCHS))
       sess.run(train_iterator.initializer)
       while True:
           try:
               train_batch_images, train_batch_labels \
                   = sess.run([train_images, train_labels])
               _, train_loss, train_acc = sess.run([fc8_train_op,loss, accuracy],
                                                feed_dict={is_training: True,
                                                           images: train_batch_images,
                                                           labels: train_batch_labels})
               val_batch_images, val_batch_label = \
                   sess.run([val_images, val_labels])
               val_loss, val_acc = sess.run([loss, accuracy],
                                            feed_dict={is_training: False,
                                                       images: val_batch_images,
                                                       labels: val_batch_label})
               # step = sess.run(global_step)
               # print("global_step:{0}".format(step))
               print("epoch:{0}, train loss:{1},train-acc:{2}".format(epoch, train_loss, train_acc))
               print("epoch:{0}, val loss:{0},val-acc:{1}".format(epoch, val_loss, val_acc))
           except tf.errors.OutOfRangeError:
               break

  • 16
    点赞
  • 48
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值