I had always used the queue-based TFRecord input pipeline; while browsing the official site I stumbled on the much more convenient tf.data. The TensorFlow site already documents it thoroughly, covering the various data formats; for other input types, see the TensorFlow Chinese documentation.
I. Creating a Dataset

1. **Definition: tf.data.TFRecordDataset**

```python
# filenames is the path (or list of paths) of the tfrecord file(s) to read
dataset = tf.data.TFRecordDataset(filenames)
```

2. **dataset.repeat(num)**

Called with no argument, it repeats the data indefinitely; if you never call it, the dataset is iterated over exactly once.

3. **dataset.shuffle(shuffle_num)**

`shuffle_num` is the shuffle buffer size. My mental model is a buffer of that many elements, drawn from at random and refilled as elements are taken out.

4. **dataset.batch(batch_size)**

Self-explanatory: it sets the batch size.

5. **dataset.map()**

A preprocessing step: it takes a function object (a lambda also works). For TFRecord data, this function is where you parse the serialized records. A combined sketch of these ops follows this list.
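Put together, these calls chain into a single input pipeline. A minimal sketch, assuming `filenames` points at your tfrecord files and `parse_fn` stands in for whatever parsing function you pass to `map` (a concrete one, `parse_single_exmp`, appears in section IV):

```python
import tensorflow as tf

# Minimal pipeline sketch: parse each record, shuffle with a 1000-element
# buffer, repeat indefinitely, and emit batches of 32.
dataset = (tf.data.TFRecordDataset(filenames)
           .map(parse_fn)
           .shuffle(1000)
           .repeat()
           .batch(32))
```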
II. Consuming data: iterators

There are several kinds of iterator: one-shot, initializable, reinitializable, and feedable.
- One-shot

A one-shot iterator is the simplest form; it supports a single pass over the dataset and needs no explicit initialization. One-shot iterators handle almost every case the existing queue-based input pipelines support, but they cannot be parameterized (i.e., as I understand it, you cannot point them at a different dataset).

```python
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
    value = sess.run(next_element)
    assert i == value
```
- Initializable

An initializable iterator requires you to run an explicit iterator.initializer op before use. In exchange for this inconvenience, it lets you parameterize the dataset definition with one or more tf.placeholder() tensors that are fed when the iterator is initialized.

```python
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
    value = sess.run(next_element)
    assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
    value = sess.run(next_element)
    assert i == value
```
- Reinitializable

A reinitializable iterator is defined by its structure (output types and shapes) rather than by one specific dataset, so it can be initialized from several different datasets that share that structure. An expanded sketch follows the snippet.

```python
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
```
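For completeness, a minimal sketch of how such an iterator is pointed at two different datasets, following the pattern from the official guide (`training_dataset` and `validation_dataset` are assumed to share output types and shapes):

```python
# One initializer op per dataset; both must match the iterator's structure.
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
next_element = iterator.get_next()

for _ in range(NUM_EPOCHS):
    sess.run(training_init_op)    # rewind onto the training data
    # ... consume next_element for training ...
    sess.run(validation_init_op)  # switch to the validation data
    # ... consume next_element for validation ...
```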
- Feedable

A feedable iterator selects among several concrete iterators at sess.run time via a string handle fed through a placeholder, so you can switch datasets without reinitializing anything. A sketch of the handle plumbing follows the snippet.

```python
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
```
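A short sketch of how the handle is obtained and fed, again following the official pattern (the surrounding datasets and session are assumed to exist):

```python
# The string handle fed at run time decides which iterator is actually read.
handle = tf.placeholder(tf.string, shape=[])
next_element = iterator.get_next()

training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# Every concrete iterator can be converted to a feedable string handle.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

sess.run(next_element, feed_dict={handle: training_handle})    # reads train data
sess.run(next_element, feed_dict={handle: validation_handle})  # reads val data
```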
III. Usage

Fetch data with next_element = iterator.get_next(). When a dataset reaches its end, the run call raises tf.errors.OutOfRangeError; catching that error is the signal that the dataset is exhausted, as sketched below.
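A minimal sketch of that catch-the-error loop, assuming a running session `sess` and an existing iterator:

```python
next_element = iterator.get_next()
while True:
    try:
        value = sess.run(next_element)    # one element (or batch) per call
    except tf.errors.OutOfRangeError:     # raised once the dataset is exhausted
        break
```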
IV. Example

The official examples feel a bit involved to me. The scheme I understand and actually use: an initializable iterator for the training data (so that finishing one pass over the training set produces a signal), and a one-shot iterator over an infinitely repeated validation set (my program runs the validation set alongside the training set to watch for overfitting). Enough talk; here is the code.
- Dataset creation function
```python
def create_dataset(filenames, batch_size=8, is_shuffle=False, n_repeats=0):
    """
    :param filenames: record file names
    :param batch_size: number of examples per batch
    :param is_shuffle: whether to shuffle the data
    :param n_repeats: 0 = one pass, >0 = that many epochs, -1 = repeat forever
    :return: a tf.data.Dataset
    """
    dataset = tf.data.TFRecordDataset(filenames)
    if n_repeats > 0:
        dataset = dataset.repeat(n_repeats)   # for training: fixed epoch count
    elif n_repeats == -1:
        dataset = dataset.repeat()            # for validation: repeat indefinitely
    dataset = dataset.map(lambda x: parse_single_exmp(x, labels_nums=NUM_CLASS))
    if is_shuffle:
        dataset = dataset.shuffle(10000)      # shuffle buffer of 10000 elements
    dataset = dataset.batch(batch_size)
    return dataset
```
- TFRecord parsing / preprocessing function
```python
def parse_single_exmp(serialized_example, labels_nums=2):
    """
    Parse one serialized tf.record example.
    :param serialized_example: a serialized tf.train.Example
    :param labels_nums: number of classes (for one-hot encoding)
    :return: image tensor, one-hot label tensor
    """
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)  # raw image bytes
    tf_height = features['height']
    tf_width = features['width']
    tf_depth = features['depth']
    tf_label = tf.cast(features['label'], tf.int32)
    # NB: when restoring the raw image data, the reshape size must match the
    # shape the image had when it was written, or this op fails.
    # tf_image = tf.reshape(tf_image, [-1])  # flatten to a row vector
    tf_image = tf.reshape(tf_image, [224, 224, 3])  # set the image dimensions
    tf_image = tf.cast(tf_image, tf.float32)
    tf_image = preprocess(tf_image, choice=True)  # user-defined preprocessing
    tf_label = tf.one_hot(tf_label, labels_nums, 1, 0)
    print(tf_image)  # debug: show the tensor's shape/dtype
    return tf_image, tf_label
```
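The post never defines the preprocessing helper it calls (spelled `prepeocess` in the original; corrected to `preprocess` above). Purely as a hypothetical sketch, `choice` might toggle random augmentation like this:

```python
def preprocess(image, choice=False):
    """Hypothetical helper: normalize, optionally apply random augmentation."""
    image = image / 255.0                               # scale pixels to [0, 1]
    if choice:
        image = tf.image.random_flip_left_right(image)  # random horizontal flip
        image = tf.image.random_brightness(image, max_delta=0.2)
    return image
```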
- Running
```python
import tensorflow as tf
import os
import tensorflow.contrib.slim as slim

# List the tfrecord files; train_dir / val_dir are the directories holding them.
train_file_names = [os.path.join(train_dir, i) for i in os.listdir(train_dir)]
val_file_names = [os.path.join(val_dir, i) for i in os.listdir(val_dir)]

# Define the datasets.
training_dataset = create_dataset(train_file_names, batch_size=BATCH_SIZE,
                                  is_shuffle=True, n_repeats=0)    # train: epoch loop drives passes
validation_dataset = create_dataset(val_file_names, batch_size=BATCH_SIZE,
                                    is_shuffle=False, n_repeats=-1)  # val: repeats forever

# Define the iterators.
train_iterator = training_dataset.make_initializable_iterator()
# make_initializable_iterator must be (re)initialized at every epoch.
val_iterator = validation_dataset.make_one_shot_iterator()
# make_one_shot_iterator needs no initialization and loops as long as needed.
train_images, train_labels = train_iterator.get_next()
val_images, val_labels = val_iterator.get_next()

# NB: sess, the placeholders (is_training, images, labels) and the ops
# (fc8_train_op, loss, accuracy) are defined elsewhere in the full program.
for epoch in range(NUM_EPOCHS):
    print('Starting epoch %d / %d' % (epoch + 1, NUM_EPOCHS))
    sess.run(train_iterator.initializer)
    while True:
        try:
            train_batch_images, train_batch_labels = \
                sess.run([train_images, train_labels])
            _, train_loss, train_acc = sess.run(
                [fc8_train_op, loss, accuracy],
                feed_dict={is_training: True,
                           images: train_batch_images,
                           labels: train_batch_labels})
            val_batch_images, val_batch_labels = \
                sess.run([val_images, val_labels])
            val_loss, val_acc = sess.run(
                [loss, accuracy],
                feed_dict={is_training: False,
                           images: val_batch_images,
                           labels: val_batch_labels})
            # step = sess.run(global_step)
            # print("global_step:{0}".format(step))
            print("epoch:{0}, train loss:{1}, train-acc:{2}".format(epoch, train_loss, train_acc))
            print("epoch:{0}, val loss:{1}, val-acc:{2}".format(epoch, val_loss, val_acc))
        except tf.errors.OutOfRangeError:  # training set exhausted: epoch done
            break
```
break