原文地址: http://blog.csdn.net/u010911921/article/details/70991194
上篇博客谈到了Tensorflow从文件中读取数据,当时采用的是CIFAR-10中的二进制数据,这次记录一下官网推荐的比较通用和高效的数据文件类型的读取——TFRecord文件,这是tensorflow指定的标准格式。
1.TFRecords
TFRecords本质上是一种二进制文件,他的优点是可以更好的利用内存空间,缺点是生成过程比较耗费时间,特别是数据量比较大的情况下。文件包含了一个tf.train.Example
的缓冲协议(protocol buffer)其中协议块中包含了字段Features
.当用程序获得数据以后,就可以将其填充到Example
的协议缓冲区(protocol buffer)中,然后在将协议缓冲区序列化为字符串,最后通过tf.python_io.TFRecordWriter
将字符串写入文件。
当从TFRecords文件中读取数据时,可以利用tf.TFRecordReader
和tf.parse_single_example
解码器,将Example
缓冲协议中的内容解析为Tensor
张量
2.notMNIST 数据集
在实验中采用的数据集合时notMNIST数据集,这个数据集合是由一些各种形态的字母组成的数据集合,总共由a~j
10个字母组成,下图是a
对应的一些图片:
另外需要注意的是,下载的数据集中有几张图片有损坏,所以处理的时候注意跳过。
3.生成TFRecords文件
为了生成TFRecords文件首先是从数据集中,将图片路径放置到一个image_list
,样本的标签放置到一个label_list
中。
#!/usr/bin/env python3
# --*-- encoding:utf-8 --*--
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
import skimage.io as io
def get_file(file_dir):
"""
get full image directory and correspond labels
:param file_dir:
:return:
"""
images =[]
temp =[]
for root ,sub_folders,files