TensorFlow高级API系列(三):Dataset API

前言

tf.data非常的好用,这里不多说,如果你停留在placeholder,feed_dict,你可能对这篇博客并不感兴趣。如果在处理大规模数据,tf.data就极其好用了。

从内存里面读取数据

我们先放代码,再慢慢解读

import tensorflow as tf
import numpy as np

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))

#如何取出数据呢?
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(5):
        print(sess.run(one_element))

从代码中我们可以看出,one_element本质上还是个tensorflow,需要sess才能打印出结果。

tf.data.Dataset.from_tensor_slices的功能不止如此,它的真正作用是切分传入Tensor的第一个维度,生成相应的dataset。

dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(5, 2)))

传入的数值是一个矩阵,它的形状为(5, 2),tf.data.Dataset.from_tensor_slices就会切分它形状上的第一个维度,最后生成的dataset中一个含有5个元素,每个元素的形状是(2, ),即每个元素是矩阵的一行。

从dict中构建dataset

dataset = tf.data.Dataset.from_tensor_slices(
    {
   
        "a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),
        "b": np.random.uniform(size=(5, 2))
    }
)

这时函数会分别切分"a"中的数值以及"b"中的数值,最终dataset中的一个元素就是类似于{“a”: 1.0, “b”: [0.9, 0.1]}的形式。

从文件中读取数据

大部分时间,我们是需要从文件中读取数据的,不可能总是从内存里面读取数据。这也是tf.data设计的初衷。目前Dataset API提供了三种从文件读取数据并创建Dataset的方式,分别用来读取不同存储格式的文件

在这里插入图片描述
常用的两个接口是前两个。
tf.data.TextLineDataset():这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。

tf.data.TFRecordDataset():顾名思义,这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample。

tf.data.FixedLengthRecordDataset():这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。这个接口我没有使用过,之后有机会再补充。

后面会有代码详细介绍这几个接口的使用。这里的接口都是可以直接读取HDFS上的数据的。

DataSet的常用变换

一个Dataset通过数据变换操作可以生成一个新的Dataset。下面介绍数据格式变换、过滤、数据打乱、生产batch和epoch等常用Transformation操作。

(1)map操作
这个操作很有用,基本读数据都会用到。
map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值取平方:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x * x) # 1.0, 4.0, 9.0, 16.0, 25.0

(2)filter操作
过滤操作
filter操作可以过滤掉dataset不满足条件的元素,它接受一个布尔函数作为参数,dataset中的每个元素都作为该布尔函数的参数,布尔函数返回True的元素保留下来,布尔函数返回False的元素则被过滤掉。

dataset = dataset.filter(filter_func)

(3)shuffle
shuffle功能为打乱dataset中的元素,它有一个参数buffer_size,表示打乱时使用的buffer的大小:

dataset = dataset.shuffle(buffer_size=10000)

(4)repeat
repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:

dataset = dataset.repeat(5)

如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常。很多代码会直接使用这个,主要原因是训练步数已经设置好了,数据可以一直重复

(5)batch
batch就是将多个元素组合成batch,如下面的程序将dataset中的每个元素组成了大小为32的batch:

dataset = dataset.batch(32)

需要注意的是,必须要保证dataset中每个元素拥有相同的shape才能调用batch方法,否则会抛出异常。在调用map方法转换元素格式的时候尤其要注意这一点。

实战代码

这里我会人造一些数据来演示代码。

解析csv文件
这种文件格式在我们平时做数据处理的时候经常遇到。
一般会使用tf.decode_csv来处理。我们先看一下这个接口的接受的一些参数,这个能够帮你方便的处理一些特殊情况。

大家可以先看一下官方api文档,相信会有帮助https://tensorflow.google.cn/api_docs/python/tf/compat/v1/decode_csv?hl=en

columns_name = ["field1", "field2", "field3", "label"]

columns_default 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值