对字节型数据集cifar10进行可视化
编程环境
ubuntu 16.04
python 2.7
tensorflow-gpu 1.4
代码
cifar10_input.py文件代码
# coding:utf-8
# 绝对引入
from __future__ import absolute_import
# 导入精确除法
from __future__ import division
# 导入print函数,print要使用括号
from __future__ import print_function
import os
import tensorflow as tf
# 定义read.cifar10函数,将数据弄成图像形式
def read_cifar10(filename_queue):
# 定义类
class CIFAR10Record(object):
pass
result = CIFAR10Record()
# 标签字节数
label_bytes = 1
# 图片的高
result.height = 32
# 图片的宽
result.width = 32
# 图片的深度
result.depth = 3
# 一张图片的字节数
image_bytes = result.height * result.width * result.depth
# 标签加图片的字节数
record_bytes = label_bytes + image_bytes
# 定义每次读多少字节
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
# 读取文件队列名中的数据
result.key, value = reader.read(filename_queue)
# tf.decode_raw函数是将原来编码为字符串类型的变量重新变回来
record_bytes = tf.decode_raw(value, tf.uint8)
# tf.strided_slice(input_, begin, end)提取张量的一部分
result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]),
[result.depth, result.width, result.height])
# transpose函数是将矩阵进行转置操做,[0, 1, 2]中0表示高(深度),1表示行,2表示列。
# [1, 2, 0]表示将原来的行作为高,列作为行,高作为列,对矩阵进行转置。
# [1, 2, 0]刚好将矩阵变为了 行×列×深度 的矩阵。
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
return result
cifar10_extract.py文件代码
# coding:utf-8
import tensorflow as tf
import os
import scipy.misc
import cifar10_input
def input_origin(data_dir):
# 读入训练图像
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]
# 判断文件是否存在
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: '+ f)
# 创建文件名队列
filename_queue = tf.train.string_input_producer(filenames)
# 读取数据
read_input = cifar10_input.read_cifar10(filename_queue)
# 将图片转换为实数形式
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
# 每调用一次sess.run,就会取出一张图片
return reshaped_image
if __name__ == '__main__':
with tf.Session() as sess:
reshaped_image = input_origin('cifar10_data/cifar-10-batches-bin')
# input_origin函数已经产生了文件名队列,现在直接start
threads = tf.train.start_queue_runners(sess=sess)
# 变量初始化
sess.run(tf.global_variables_initializer())
# 创建文件夹
if not os.path.exists('cifar10_data/raw'):
os.makedirs('cifar10_data/raw')
for i in range(30):
image_array = sess.run(reshaped_image)
# scipy.misc.toimage为将一个numpys数组保存为图像
scipy.misc.toimage(image_array).save('cifar10_data/raw/%d.jpg' % i)
代码说明
首先建立一个文件夹,例如CIFAR-10,在里面放入已经下载好的CIFAR10数据集,并且提取出来,可以得到图2中的cifar-10-batches-bin文件夹。再放入编写好的cifar10_input.py文件和cifar_extract.py文件。
cifar10_data数据集里面的数据如图3所示。
实验结果
最终代码会提取出30张(可自行调整)图片放在cifar10_data文件夹下raw文件夹里面。