转载请注明作者和出处: http://blog.csdn.net/wiinter_fdd/article/details/72835939
运行平台: Windows
Python版本: Python3.x
IDE: Spyder
前言
最近一直在研究深度学习,主要是针对卷积神经网络(CNN),接触过的数据集也有了几个,最经典的就是MNIST, CIFAR10/100, NOTMNIST, CATS_VS_DOGS 这几种,由于这几种是在深度学习入门中最被广泛应用的,所以很多深度学习框架 Tensorflow、keras和pytorch都有针对这些数据集专用的数据导入的函数封装,但是一般情况下我们的数据集并不是这种很规范的形式,那么如何把自己的数据集转换成这些框架能够使用的数据形式至关重要,接下来博主将会对现有的较流行的深度学习框架封装自己的数据进行讲解,首先是针对最流行的Tensorflow。
查阅tensorflow的官方API,在GET STARTED下面的Programmer’s Guide中有一个Reading Data的章节介绍,大体内容就是tensorflow读取数据的方式:
可以看到,tensorflow官网给出了三种读取数据的方法:
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,大家自己尝试一下吧,我就不赘述了)。但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecord.
那下面就让我们了解一下什么是TFRecord:
1. What is TFRecord?
TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件(等会儿就知道为什么了)… …总而言之,这样的文件格式好处多多,所以让我们用起来吧。这里注意:TFRecord会根据你输入的文件的类,自动给每一类打上同样的标签。
TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义:
message Example {
Features features = 1;
};
message Features{
map<string,Feature> featrue = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
从上述代码可以看到,ft.train.Example 的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),实数列表(FloatList )或整数列表(Int64List )。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。
2. How to convert our own data to TFRecord?
终于我们关心的话题来了,怎么转换?这里我们使用Kaggle上面有名猫狗大战的数据集可以通过Dogs vs Cats来下载,为了方便演示,我们利用这个数据集创建了一个新的数据集,取猫狗图片中各100张分别放在data文件夹下面的cats和dogs子文件中,入下图所示。
数据准备好以后,我们就要开始读取数据&#x