Tensorflow读取数据-tf.data.TFRecordDataset

tensorflow TFRecords文件的生成和读取方法


备注:本文参考https://zhuanlan.zhihu.com/p/31992460( 公众号《人工智能技术干货》,专注深度学习与计算机视觉!)

​ 本文参考https://www.tensorflow.org/api_docs/python/tf/io/TFRecordWriter (Tensorflow官方教程)

​ 本文参考https://zhuanlan.zhihu.com/p/43356309 【0.1】Tensorflow踩坑记之tf.data

​ 本文参考https://zhuanlan.zhihu.com/p/30751039 TensorFlow全新的数据读取方式:Dataset API入门教程

https://blog.csdn.net/qq_36556054/article/details/102872885?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0.pc_relevant_default&spm=1001.2101.3001.4242.1&utm_relevant_index=3

1. TFRecords说明

TFRecords是一种tensorflow的内定标准文件格式,其实质是二进制文件,遵循protocol buffer(PB)协议,其后缀一般为tfrecord。TFRecords文件方便复制和移动,能够很好的利用内存,无需单独标记文件,适用于大量数据的顺序读取,是tensorflow“从文件里读取数据”的一种官方推荐方法!本篇文章,我将整理tensorflow TFRecords文件生成和读取的方法, 分为两个部分,每个部分分别介绍并附带例程!

TFrecord 是TensorFlow使用的一种数据格式,他可以把多个训练的图片许多信息压缩在一个文件中,用特殊的方式存储和读取,通过tf.dataset 这个API进行快速的读取和写入。具体使用官方教材,参考官方文档,里面有具体的使用方法,最近又出了高阶的使用方法,等流程跑通了再继续优化,Tensflow公众号-tf.data API,构建高性能 TensorFlow 输入管道。

将数据处理成TFRecord的形式,是tensorflow官方推荐的一种文件可以。使用这种文件格式是官方推荐,具体原因如下:

  • 该文件格式方便复制和移动,能够很好的利用内存;
  • 支持String,Float,Int类型的数据,方便存储结构化的标注数据;

2.关键API

2.1 tf.io.TFRecordWriter类

把记录写入到TFRecords文件的类.

__init__(path,options=None)

作用:创建一个TFRecordWriter对象,这个对象就负责写记录到指定的文件中去了.
参数:
path: TFRecords 文件路径
options: (可选) TFRecordOptions对象

close()

作用:关闭对象.

write(record)

作用:把字符串形式的记录写到文件中去.
参数:
record: 字符串,待写入的记录

2.2 tf.train.Example

tf.Example 的数据类型

这个类是非常重要的,在Tensorflow中样本数据的序列化保存一般采用tfrecord的文件格式,其根本原因在文章开头就已经描述。在tfrecord文件中,实质上是一堆tf.train.Example的集合。

tf.train.Example的数据集包含三个类:tf.train.Feature -> tf.train.Features -> tf.train.Example三个类。具体的关系如下:

example = tf.train.Example(features=tf.train.Features(feature={...}))

具体例子如下:

record_bytes = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
        "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
    }))

上述描述,大家对Example这个类有一个基本的了解。

属性:

features : 是一个tf.train.Features

函数:

__init__(**kwargs)

这个函数是初始化函数,会生成一个Example对象,一般我们使用的时候,是传入一个tf.train.Features对象进去.

SerializeToString()

作用:把example序列化为一个字符串,因为在写入到TFRcorde的时候,write方法的参数是字符串的.

2.3 tf.train.Features

作用:初始化Features对象,一般我们是传入一个字典,字典的键是一个字符串,表示名字,字典的值是一个tf.train.Feature对象.

tf.train.Features(feature ={
	"x":tf.train.Feature(bytes_list=tf.train.BytesList=(value=[])),
	"y":tf.train.Feature(int64_list=tf.train.Int64List(value = []))
})

2.4 tf.train.Feature

包含属性:

bytes_list 对应的对象是:tf.train.BytesList(value=[])
float_list 对应的对象是:tf.train.FloatList(value=[])
int64_list 对应的对象是:tf.train.Int64List(value=[])

2.5 tf.train.Int64List, tf.train.BytesList, tf.train.FloatList

这三个数据类型

  1. tf.train.BytesList(可强制转换自以下类型)
  • string
  • byte
  1. tf.train.FloatList(可强制转换自以下类型)
  • float (float32)
  • double (float64)
  1. tf.train.Int64List(可强制转换自以下类型)
  • bool
  • enum
  • int32
  • uint32
  • int64
  • uint6

2.6 tf.io.TFRecordReader类

To create an input pipeline, you must start with a data source. For example, to construct a Dataset from data in memory, you can use tf.data.Dataset.from_tensors() or tf.data.Dataset.from_tensor_slices(). Alternatively, if your input data is stored in a file in the recommended TFRecord format, you can use tf.data.TFRecordDataset().

将原始的特征数据处理成结构化的tfrecord数据集。

feature -> Features -> Example, 三者按顺序为包含关系

example = tf.train.Example(features=tf.train.Features(feature={...}))
tf_writer.write(example.SerializeToString())  # 序列化写入tfrecord

2.7 tf.data.TFRecordDataset类

Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服务于数据读取,构建输入数据的pipeline。

对Dataset中的元素做变换:Transformation

**Dataset支持一类特殊的操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。**通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作。

常用的Transformation有:

  • map
  • batch
  • shuffle
  • repeat

下面就分别进行介绍。

(1)map

map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
(2)batch

batch就是将多个元素组合成batch,如下面的程序将dataset中的每个元素组成了大小为32的batch:

dataset = dataset.batch(32)
(3)shuffle

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:

dataset = dataset.shuffle(buffer_size=10000)
(4)repeat

repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:

dataset = dataset.repeat(5)

如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常:

dataset = dataset.repeat()

3.写操作

下面以两个例子来说明一下写操作。

例子1:鸢尾花数据集-写操作

第一个例子是将鸢尾花的数据集iris.csv,将这部分数据写入到iris.tfrecords中。

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("input","../data/iris.csv","输入")
flags.DEFINE_string("output","../data/iris.bin","输入")

def main(_):
    inputfile = FLAGS.input
    outputfile = FLAGS.output
    write = tf.io.TFRecordWriter(outputfile)
    idx = -1
    with tf.io.gfile.GFile(inputfile,'r') as reader:
        for line in reader:
            idx += 1   #为了跳过第一行的description
            if idx < 1 :
                continue
            splits = line.split(",")
            x = [float(i) for i in splits[:-1]]
            y = [int(splits[-1])]
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'x': tf.train.Feature(
                            float_list=tf.train.FloatList(value=x)  # 方括号表示输入为list,一般tf.train.FloatList被用来处理浮点数
                        ),
                        'y': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=y)  # B_data本身就是列表,一般tf.train.Int64List被用来处理整数
                        )
                    }
                )
            )
            print(example.SerializeToString())
            write.write(example.SerializeToString())
    write.close()

if __name__ == "__main__":
    tf.app.run()

例子2:


4.读操作

在TensorFlow 1.3中,Dataset API是放在contrib包中的:

tf.contrib.data.Dataset

而在TensorFlow 1.4中,Dataset API已经从contrib包中移除,变成了核心API的一员:

tf.data.Dataset

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RmKWUi95-1643547854392)(…/…/…/Images/TF-DataSet.png)]

例子1 鸢尾花数据集-读操作

读取数据中,

import tensorflow as tf

flags = tf.flags

FLAGS = flags.FLAGS


flags.DEFINE_string("input" , "../data/iris.bin" , "input file")

def decord_fn(dataset):
    features = tf.io.parse_single_example(
    dataset ,
    {
        'x':tf.io.FixedLenFeature([4] , tf.float32 ) ,
        'y':tf.io.FixedLenFeature([1] , tf.int64 )
    })
    return features

def main(_):
    dataset = tf.data.TFRecordDataset(FLAGS.input).map(decord_fn)
    dataset = dataset.batch(10)
    NUM_EPOCHS = 2
    dataset = dataset.repeat(NUM_EPOCHS)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.Session() as sess:
        sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
        #sess.run(iterator.initializer)
        while True:
            try:
                x_train = sess.run(next_element)  # 取出设定的
                print(x_train)
                # print("x = {x:.4f},  y = {y:.4f}".format(x_train['x'],x_train['y']))
            except tf.errors.OutOfRangeError:
                break
if __name__ == "__main__":
    #flags.mark_flag_as_required("input")
    #flags.mark_flag_as_required("output")
    tf.app.run()
import tensorflow as tf

flags = tf.flags

FLAGS = flags.FLAGS


flags.DEFINE_string("input" , "../data/iris.bin" , "input file")

def decord_fn(dataset):
    features_dict  = {
        "x":tf.io.FixedLenFeature([4] , tf.float32 ),
        "y":tf.io.FixedLenFeature([1] , tf.int64)
    }
    results = tf.io.parse_single_example(dataset , features=features_dict)
    return  results
def main(_):
    dataset = tf.data.TFRecordDataset(FLAGS.input)
    dataset = dataset.map(decord_fn)
    dataset = dataset.batch(10)
    NUM_EPOCHS = 2
    dataset = dataset.repeat(NUM_EPOCHS)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    global_init = tf.global_variables_initializer()
    local_init = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run([global_init , local_init])
        sess.run(iterator.initializer)
        while True:
            try:
                out = sess.run(next_element)
                print(out)
            except tf.errors.OutOfRangeError:
                break
if __name__ == "__main__":
    #flags.mark_flag_as_required("input")
    #flags.mark_flag_as_required("output")
    tf.app.run()

5.一些问题

在使用tfrecords的时候主要遇到一些问题,主要包括:

一是处理大量数据的的时候文件太大(一万张图片有700M),并且生成的太慢,

第二就是训练的时候将图片value和label打印出来,发现不对应。 但写入的时候label和image并没有错乱,也排查了读取的代码以及数据类型的问题,但是的确会出现image和label不对应的问题,我想问一下您在制作且用于训练的时候显示过吗,出现过这种情况吗?能够确定自己的image和label是对应的吗

  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值