【tensorflow:实战Google深度学习框架】-tip3 tf.data.Dataset

参考链接:https://blog.csdn.net/lyb3b3b/article/details/82910863
tf.data.Dataset的API导入
在tf 1.3.0及其以下版本中,Dataset API是放在contrib包中的:

tf.contrib.data.Dataset

从tf 1.4.0开始该API独立出来:

tf.data.Dataset

一、基本概念:Dataset与Iterator

在这里插入图片描述
在初学时,我们只需要关注两个最重要的基础类:Dataset和Iterator。

Dataset可以看作是相同类型“元素”的有序列表。在实际使用时,单个“元素”可以是向量,也可以是字符串、图片,甚至是tuple或者dict。
from_tensor_slices是这个Dataset类的一个方法

  • 先tf.data.Dataset.from_tensor_slices产生数据集Dataset,经过实例化,才产生迭代器Iterator。

  • 注意迭代器Iterator分为:

iterator = dataset.make_one_shot_iterator() 只能读一次

iterator = dataset.make_initializable_iterator() (这个需要iterator.initializer初始化

import tensorflow as tf
import numpy as np

'''创建dataset'''
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))

'''实例化iterator'''
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(5):
        print(sess.run(one_element))  # 则输出1.0 2.0 3.0 4.0 5.0
# 或者
# 不过,make_initializable_iterator的情况需要初始化
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    # 注意:这里多了一个初始化,
    sess.run(iterator.initializer)
    try:
        while True:
            print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
        print('end')
  • 使用dataset=tf.data.Dataset.from_tensor_slices创建dataset
  • iterator = dataset.make_one_shot_iterator()从dataset中实例化了一个Iterator
  • one_element = iterator.get_next()表示从iterator里取出一个tensor
  • 使用sess.run,才可以获取真正的值

二、高维数据集使用

tf.data.Dataset.from_tensor_slices真正的作用是切分传入Tensor的第一个维度,生成相应的dataset。

dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(5, 2)))

传入的数值是一个矩阵,它的形状为(5, 2),tf.data.Dataset.from_tensor_slices就会切分它形状上的第一个维度,最后生成的dataset中一个含有5个元素,每个元素的形状是(2, ),即每个元素是矩阵的一行。

在实际使用中,我们可能还希望Dataset中的每个元素具有更复杂的形式,如每个元素是一个Python中的元组,或是Python中的词典
例如,输入是训练集和标签的tuple,生成的每条记录也是tuple

dataset = tf.contrib.data.Dataset.from_tensor_slices(
  ( np.random.uniform(size=(5, 2)), np.array([1.0, 2.0, 3.0, 4.0, 5.0])))
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        print("end!")
输出:
(array([6.55877282e-04, 6.63244735e-01]),1.0)
(array([0.04756927, 0.44968581]),2.0)
(array([0.97841076, 0.06465231]),3.0)
(array([0.46639246, 0.39146086]),4.0)
(array([0.61085016, 0.61609538]),5.0)

例如,在图像识别问题
一个元素可以是{“image”: image_tensor, “label”: label_tensor}的形式,这样处理起来更方便。
tf.data.Dataset.from_tensor_slices同样支持创建这种dataset,例如我们可以让每一个元素是一个词典。

dataset = tf.contrib.data.Dataset.from_tensor_slices(
    {
        "a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                       
        "b": np.random.uniform(size=(5, 2))
    }
)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        print("end!")
输出:
{'a': 1.0, 'b': array([0.31721037, 0.33378767])}
{'a': 2.0, 'b': array([0.99221946, 0.65894961])}
{'a': 3.0, 'b': array([0.98405468, 0.11478854])}
{'a': 4.0, 'b': array([0.95311317, 0.57432678])}
{'a': 5.0, 'b': array([0.46067428, 0.19716722])}

三、对Dataset中的元素做变换

Dataset支持一类特殊的操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作。
常用的Transformation有:

  • Map
  • batch
  • shuffle
  • repeat

1 .map

map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0

2.batch

batch就是将多个元素组合成batch,如下面的程序将dataset中的每个元素组成了大小为2的batch:

dataset = tf.data.Dataset.from_tensor_slices(
    {
        "a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                      
        "b": np.random.uniform(size=(5, 2))
    })
dataset = dataset.batch(2) 
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session(config=config) as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        print("end!")
输出
{'a': array([1., 2.]), 'b': array([[0.87466134, 0.21519021], [0.6123372 , 0.95722733]])}
{'a': array([3., 4.]), 'b': array([[0.76964374, 0.22445015], [0.08313089, 0.60531841]])}
{'a': array([5.]), 'b': array([[0.37901654, 0.3955096 ]])}

3.shuffle

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:
dataset = dataset.shuffle(buffer_size=10000)

4.repeat

repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:dataset = dataset.repeat(5)
如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常:dataset = dataset.repeat()

5.读入磁盘图片与对应label

我们可以来考虑一个简单,但同时也非常常用的例子:读入磁盘中的图片和图片相应的label,并将其打乱,组成batch_size=32的训练样本。在训练时重复10个epoch。

# 函数的功能时将filename对应的图片文件读进来,并缩放到统一的大小
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label
 
# 图片文件的列表
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
# label[i]就是图片filenames[i]的label
labels = tf.constant([0, 37, ...])
 
# 此句后dataset中的一个元素是(filename, label)
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
 
# 此句后dataset中的一个元素是(image_resized, label)
dataset = dataset.map(_parse_function)
 
# 此句后dataset中的一个元素是(image_resized_batch, label_batch)
dataset = dataset.shuffle(buffersize=1000).batch(32).repeat(10)

这个过程中,dataset经历三次转变:

  • 运行dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))后,dataset的一个元素是(filename, label)。filename是图片的文件名,label是图片对应的标签。
  • 之后通过map,将filename对应的图片读入,并缩放为28x28的大小。此时dataset中的一个元素是(image_resized, label)。
  • 最后,dataset.shuffle(buffersize=1000).batch(32).repeat(10)的功能是:在每个epoch内将图片打乱组成大小为32的batch,并重复10次。最终,dataset中的一个元素是(image_resized_batch, label_batch),image_resized_batch的形状为(32, 28, 28, 3),而label_batch的形状为(32, ),接下来我们就可以用这两个Tensor来建立模型了。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值