Tensorflow踩坑记之tf.data
今天尝试总结一下 tf.data 这个API的一些用法吧。之所以会用到这个API,是因为需要处理的数据量很大,而且数据均是分布式的存储在多台服务器上,所以没有办法采用传统的喂数据方式,而是运用了 tf.data 对数据进行了相应的预处理,并且最近正赶上总结需要,尝试写一下关于 tf.data 的一些用法,有错误的地方一定告诉我哈。
Tensorflow的数据读取
先来看一下Tensorflow的数据读取机制吧
这一篇文章对于 tensorflow的数据读取机制 讲解得很不错,大噶可以先看一下,有一个了解。
Dataset API是怎么用的呢
虽然上面的资料关于 tf.data 讲解得都很好,但是我没有找到一个很完整滴运用 tf.data.TextLineDataset() 和 tf.data.TFRecordDataset() 的例子,所以才想尝试写一写这篇总结。
MNIST的经典例子
本篇博客结合 mnist 的经典例子,针对不同的源数据:csv数据和tfrecord数据,分别运用 tf.data.TextLineDataset() 和 tf.data.TFRecordDataset() 创建不同的 Dataset 并运用四种不同的 Iterator ,分别是 单次,可初始化,可重新初始化,以及可馈送迭代器 的方式实现对源数据的预处理工作。
我将相关的资料放在了澜子的Github 上,欢迎互粉哇(星星眼)。其中包括了所需的 后缀名为csv和tfrecords的源数据 (data
的文件夹),以及在 jupyter notebook实现的具体代码 (tf_dataset_learn.ipynb
)。
如果有需要的同学可以直接
git clone https://github.com/lanhongvp/tensorflow_dataset_learn.git
然后用 jupyter 跑一跑看看输出,这样可以有一个比较直观的认识。关于 Git和Github 的使用,大噶可以看我VSCODE_GIT这一篇博客啦。接下来,针对MNIST例子做一个简单的说明吧。
tf.data.TFRecordDataset() & make_one_shot_iterator()
tf.data.TFRecordDataset() 输入参数直接是后缀名为tfrecords
的文件路径,正因如此,即可解决数据量过大,导致无法单机训练的问题。本篇博客中,文件路径即为/Users/honglan/Desktop/train_output.tfrecords
,此处是我自己电脑上的路径,大家可以 根据自己的需要修改为对应的文件路径。
make_one_shot_iterator() 即为单次迭代器,是最简单的迭代器形式,仅支持对数据集进行一次迭代,不需要显式初始化。
配合 MNIST数据集以及tf.data.TFRecordDataset(),实现代码如下。
# Validate tf.data.TFRecordDataset() using make_one_shot_iterator()
import tensorflow as tf
import numpy as np
num_epochs = 2
num_class = 10
sess = tf.Session()
# Use `tf.parse_single_example()` to extract data from a `tf.Example`
# protocol buffer, and perform any additional per-record preprocessing.
def parser(record):
keys_to_features = {
"image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
"pixels": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
"label": tf.FixedLenFeature((), tf.int64,