tensorflow可以读取样本长度固定的二进制文件,比如CIFAR-10数据,该二进制数据中一个样本由1字节的label和32*32*3字节的image组成。TFRecords是tensorflow设计的一种内置的文件格式,是一种二进制文件,它能更好地利用内存,更方便复制和移动。
该程序实现tensorflow首先读取CIFAR-10的二进制数据,然后将其保存成tfrecords格式的文件,最后实现对tfrecords文件的读取。
一 Tensorflow读取二进制文件
1、构造文件队列
file_queue = tf.train.string_input_producer(file_list) # file_list:文件列表
2、构建二进制文件阅读器,读取内容(读取一个样本字节大小)
reader = tf.FixedLengthRecordReader(bytes_length) # bytes_length:一个样本字节大小
key, value = reader.read(file_queue)
3、解码内容,二进制文件中读取为uint8格式
label_image = tf.decode_raw(value, tf.uint8)
4、分割出特征值和标签值
label = tf.cast(tf.slice(label_image, [0], [1]), tf.int32)
image = tf.slice(label_image, [1], [3073])
5、对图片的特征值数据进行形状改变[3072]-->[32, 32,3],方便后面批处理
image_reshape = tf.reshape(image, [32,32,3])
6、批处理数据
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
二 将数据写入TFRecords文件
1、建立TFRecords存储器
writer = tf.python_io.TFRecordWriter(save_path) # save_path:保存路径+文件名
2、循环将所有样本写入文件,每个样本都要构造example