tf.data学习指南(超实用、超详细)

tf.data是一个极其好用的数据读取管道搭建的API,甚至在tensorflow2.0中完全代替了诸如队列等其他方法。
使用该API构建数据管道,主要依靠两个API:

  • tf.data.Dataset
  • tf.data.Iterator

tf.data.Dataset用于读入数据,做预处理,调整batch和epoch等常规操作,而读取数据则依赖于Iterator接口。这二者的关系和Pytorch中的Dataset和DataLoader很相似。
另外,由于Tensorflow2和tensorflow1代的特性不一样,本文将分成两个部分讲解tf.data的在两个版本的用法。

公共的基础知识

不管是在什么版本,1代和2代在构建数据集(dataset)上都是一样的,不一样的地方在于tf.data.Iterator如何使用,即读取数据的方式不一样。所以这一节将介绍如何搭建一个tf.data.Dataset模块。

  • 已知文件名称和标签,用data保存每一个文件的地址,用label保存每一文件对应的标签。data和label都是列表,形式如data = [‘xxxx.jpg’,‘qqqq.jpg’,…]; label = [0,2,3,4,1,…]
import tensorflow as tf
import tensorflow.contrib.eager as tfe
import os
# tf.enable_eager_execution()

file_path = r'E:\dataset\DAVIS\JPEGImages\480p\bear'
data= [os.path.join(file_path,i) for i in os.listdir(file_path)]
label = [0]*len(data)
print(data)
print(len(label))

#['E:\\dataset\\DAVIS\\JPEGImages\\480p\\bear\\00000.jpg', 'E:\\dataset\\DAVIS\\JPEGImages\\480p\\bear\\00001.jpg', 
......]
#82

通过上面的代码,我们得到了训练数据的图像地址,至于label,因为我只是做示范,所以就生成了全0的label,长度和读取到的图像地址一致。

先看一种最简单也是最常用的方式,使用tf.data.Dataset.from_tensor_slices

dataset = tf.data.Dataset.from_tensor_slices((data,label))
print(datset)
#  <DatasetV1Adapter shapes: ((), ()), types: (tf.string, tf.int32)>

这就算是构建了一个dataset了。tf.data.Dataset类有几个函数会经常使用到。

  • batch()
  • repeat()
  • shuffle()
  • map()
  • zip() // 以后再介绍,用的不多,但功能强大。

batch():用一个整型数字作为参数,描述了一个batch的batch size。
repeat():参数同样是一个整型数字,描述了整个dataset需要重复几次(epoch),如果没有参数,则重复无限次。
shuffle():顾名思义
map():常常用作预处理,图像解码等操作,参数是一个函数句柄,dataset的每一个元素都会经过这个函数的到新的tensor代替原来的元素。

接下来就要结合Iterator来理解以上4个方法的使用。

在Tensorflow1.x中构建管道

使用Iterator来得到数据集dataset类型的数据接口。Dataset类型提供直接生成迭代器的函数:

  • tf.data.Dataset.make_one_shot_iterator()
  • tf.data.Dataset.make_initializable_iterator()
  • make_one_shot_iterator不需要用户显示地初始化,但是仅仅能迭代(遍历)一次数据集。
  • make_initializable_iterator需要用户显示地初始化,并且在初始化时可以送入之间定义的placeholder,达到更加灵活的使用。

迭代器有get_next方法,来获得数据的tensors。该tensors的含义就是从构建数据集所用的from_tensors_slice的参数形式。

下面看两个例子。

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import os
# tf.enable_eager_execution()  # 先用一代函数

file_path = r'E:\dataset\DAVIS\JPEGImages\480p\bear'
data = [os.path.join(file_path,i) for i in os.listdir(file_path)]
label = [0]*len(data)
# print(data)
# print(len(label))
dataset = tf.data.Dataset.from_tensor_slices((data,label))
# print(dataset)
iterator = dataset.make_one_shot_iterator()
img_name, label = iterator.get_next()

with tf.Session() as sess:
    while 1:
        try:
            name, num = sess.run([img_name,label])
            print(name)
            assert num == 0, "fail to read label"
        except tf.errors.OutOfRangeError:
            print("iterator done")
            break

# b'E:\\dataset\\DAVIS\\JPEGImages\\480p\\bear\\00000.jpg'
# b'E:\\dataset\\DAVIS\\JPEGImages\\480p\\bear\\00001.jpg'
# b'E:\\dataset\\DAVIS\\JPEGImages\\480p\\bear\\00002.jpg'
# .....
# iterator done

我们发现输出的是图像名称,因为我们送到dataset的内容就是图像的路径信息,但是现在想读取图像数据。这就需要用到map方法了。
先自定义一个读取图像的函数:

def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string,channels=3)
  image_resized = tf.image.resize_images(image_decoded, [224, 224])
  return image_resized, label
file_path = r'E:\dataset\DAVIS\JPEGImages\480p\bear'
data = [os.path.join(file_path,i) for i in os.listdir(file_path)]
label = [0]*len(data)
# print(data)
# print(len(label))
dataset = tf.data.Dataset.from_tensor_slices((data,label))
dataset = dataset.map(_parse_function)   # 在这里加入
# print(dataset)
iterator = dataset.make_one_shot_iterator()
img, label = iterator.get_next()

with tf.Session() as sess:
    while 1:
        try:
            image, num = sess.run([img,label])
            print(image.shape)
            assert num == 0, "fail to read label"
        except tf.errors.OutOfRangeError:
            print("iterator done")
            break


# ....
# (224, 224, 3)
# (224, 224, 3)
# (224, 224, 3)
# iterator done

接着我们使用dataset的batch方法,每一次读取一个batch。

dataset = tf.data.Dataset.from_tensor_slices((data,label))
dataset = dataset.map(_parse_function)
dataset = dataset.batch(5)

# 其他不变,输出是
(5, 224, 224, 3)
(5, 224, 224, 3)
(2, 224, 224, 3)
iterator done
一共82张图像,所以最后剩2张图像组合为一个batch。

接着我们使用repeat把数据集重复两次。然后用一个全局变量count记录总数,看看是不是等于 82*2 = 164。

file_path = r'E:\dataset\DAVIS\JPEGImages\480p\bear'
data = [os.path.join(file_path,i) for i in os.listdir(file_path)]
label = [0]*len(data)
# print(data)
# print(len(label))
dataset = tf.data.Dataset.from_tensor_slices((data,label))
dataset = dataset.map(_parse_function)
dataset = dataset.batch(1)   # 改为1,方便计数
dataset = dataset.repeat(2)   # 数据集重复两次
# print(dataset)
iterator = dataset.make_one_shot_iterator()
img, label = iterator.get_next()

count = 0
with tf.Session() as sess:
    while 1:
        try:
            image, num = sess.run([img,label])
            print(image.shape)
            count += 1
        except tf.errors.OutOfRangeError:
            print("iterator done")
            print("count is ",count)  # 打印conut 作为验证
            break

# ....
(1, 224, 224, 3)
(1, 224, 224, 3)
iterator done
count is  164

OK,确实如我们所想。

再来说下tf.data.Dataset.make_initializable_iterator()
该方法返回了一个可初始化的迭代器。用户需要使用sess.run(iterator.initializer)显示初始化。

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import os
# tf.enable_eager_execution()
from read import _parse_function as _parse_function
file_path = r'E:\dataset\DAVIS\JPEGImages\480p\bear'
data = [os.path.join(file_path,i) for i in os.listdir(file_path)]
label = [0]*len(data)
dataset = tf.data.Dataset.from_tensor_slices((data,label))
dataset = dataset.map(_parse_function)
batch_size = tf.placeholder(tf.int64,shape=[])
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(1)
iterator = dataset.make_initializable_iterator()
img, label = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={batch_size: 20})
    while 1:
        try:
            image, num = sess.run([img,label])
            print(image.shape)
        except tf.errors.OutOfRangeError:
            print("second iterator done")
            break
    sess.run(iterator.initializer, feed_dict={batch_size: 40})   # 送入batchsize的确切值
    while 1:
        try:
            image, num = sess.run([img,label])
            print(image.shape)
        except tf.errors.OutOfRangeError:
            print("first iterator done")
            break
# 输出是
# (20, 224, 224, 3)
(20, 224, 224, 3)
(20, 224, 224, 3)
(20, 224, 224, 3)
(2, 224, 224, 3)
first iterator done
(40, 224, 224, 3)
(40, 224, 224, 3)
(2, 224, 224, 3)
second iterator done

可以看出,我第一次在初始化迭代器的时候,送入了batchsize 是20,于是第一次迭代的batch就是20个样本,一共82个没毛病。第二次初始化迭代器的时候,batchsize设置的是40,于是第二次迭代的batch就是40个样本,一共82个样本,也没毛病。

Note:tf.data还有可重新初始化迭代器和可feeding迭代器。因为基本上可以被以上两种迭代器在功能上代替,我就不介绍了。更详细的介绍可以参见here

另外迭代器的状态是可以保存的,比如训练到一半,样本读到哪里了,把这个状态保存起来,下一次运行导入模型,之后接着上一次的状态继续运行。

 saveable_iter  = tf.data.experimental.make_saveable_from_iterator(iterator)
 tf.compat.v1.add_to_collections(tf.GraphKeys.SAVEABLE_OBJECTS,saveable_iter)
 saver = tf.train.Saver()
 with tf.Session() as sess:
 	.....
 	saver.save(path_to_checkpoint)
 
 # restore the state of iterator
 with tf.Session() as sess:
 	saver.restore(sess,path_to_checkpoint)

在Tensorflow2.0中构建管道

TF2.0支持和PyTroch一样的eager模式,所以在该模式下,session和placehoder被弃用。这种情况下,获取数据的读取接口仅仅和tf1.x有一点区别,而且更加简单好用。

我们在建利dataset之后,仅仅用python内置函数iter即可。
例子:

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import os
tf.enable_eager_execution()   # 开起eager模式,博主用的1.13.1,如果使用的是2.0,则忽略这行
from read import _parse_function as _parse_function

file_path = r'E:\dataset\DAVIS\JPEGImages\480p\bear'
data = [os.path.join(file_path,i) for i in os.listdir(file_path)]
label = [0]*len(data)
dataset = tf.data.Dataset.from_tensor_slices((data,label))
dataset = dataset.map(_parse_function)
dataset = dataset.batch(40)
dataset = dataset.repeat(1)
iterator = iter(dataset)

while 1:
    try:
        image, _ = next(iterator)
        print(image.shape)
    except StopIteration:   # python内置的迭代器越界错误类型
        print("iterator done")
        break;
        
# 输出是
(40, 224, 224, 3)
(40, 224, 224, 3)
(2, 224, 224, 3)
iterator done

我们也可以使用
import tensorflow.contrib.eager as tfe所支持的Iterator类生成迭代器

....   # 和上面一样
iterator = tfe.Iterator(dataset)

for img, _ in iterator:
    print(img.shape)

# (40, 224, 224, 3)
(40, 224, 224, 3)
(2, 224, 224, 3)

参考链接

[0.2] Tensorflow踩坑记之头疼的tf.data
TensorFlow全新的数据读取方式:Dataset API入门教程
GitHub上一个例子
cs230课程
TensorFlow tf.data 导入数据(tf.data官方教程)

  • 18
    点赞
  • 83
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值