# -*- coding:utf-8 -*-
"""Read CIFAR-10 binary batches with tf.data and dump sample images as JPEGs.

Tested against TensorFlow 1.14.0 (uses the tf.compat.v1 session API).
"""
__author__ = 'Leo.Z'
# Tensorflow Version:1.14.0

import os

import tensorflow as tf
from PIL import Image

# Number of records pulled per iterator step; also used as the shuffle buffer size.
BATCH_SIZE = 128
def read_cifar10(filenames):
    """Build a batched (label, image) tensor pair from CIFAR-10 .bin files.

    Args:
        filenames: list of paths to CIFAR-10 binary batch files, where each
            record is 1 label byte followed by 32*32*3 image bytes.

    Returns:
        Tuple ``(label, image)`` of tensors produced by a one-shot iterator:
        ``label`` is int32 with shape [BATCH_SIZE, 1]; ``image`` is uint8 with
        shape [BATCH_SIZE, 32, 32, 3] (HWC order, ready for PIL).
    """
    # CIFAR-10 binary record layout: 1 label byte + 32x32 RGB image bytes.
    label_bytes = 1
    height = 32
    width = 32
    depth = 3
    image_bytes = height * width * depth
    record_bytes = label_bytes + image_bytes

    # tf v1.14.0: FixedLengthRecordDataset(filename_list, bin_record_len)
    # slices the files into fixed-size raw records.
    datasets = tf.data.FixedLengthRecordDataset(filenames=filenames,
                                                record_bytes=record_bytes)
    # Shuffle, repeat for 2 epochs, then batch.
    datasets = datasets.shuffle(buffer_size=BATCH_SIZE).repeat(2).batch(BATCH_SIZE)

    # Create a one-shot iterator (tf v1.14.0).  NOTE(review): decoding to uint8
    # inside datasets.map() did not take effect here, so the decode happens
    # after get_next() instead — confirm before moving it back into map().
    iterator = tf.compat.v1.data.make_one_shot_iterator(datasets)
    # Next batch of raw records: label + image bytes, 1 + 32*32*3 per record.
    rec = iterator.get_next()
    rec = tf.decode_raw(rec, tf.uint8)

    # Byte 0 of each record is the label.
    label = tf.cast(tf.slice(rec, [0, 0], [BATCH_SIZE, label_bytes]), tf.int32)
    # Remaining bytes are the image, stored channel-major: [C, H, W].
    depth_major = tf.reshape(
        tf.slice(rec, [0, label_bytes], [BATCH_SIZE, image_bytes]),
        [BATCH_SIZE, depth, height, width])
    # Reorder axes to [batch, H, W, C].
    image = tf.transpose(depth_major, [0, 2, 3, 1])
    return (label, image)
def get_data_from_files(data_dir):
    """Return the (label, image) batch tensors for the CIFAR-10 training set.

    Args:
        data_dir: directory containing data_batch_1.bin .. data_batch_5.bin
            (the five training batches; all are training images).

    Returns:
        The ``(label, image)`` tensor tuple from :func:`read_cifar10`.

    Raises:
        ValueError: if any of the five batch files is missing.
    """
    # Five training files: data_batch_1.bin through data_batch_5.bin.
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in range(1, 6)]
    # Fail fast if any expected file is absent.
    for f in filenames:
        if not tf.io.gfile.exists(f):
            raise ValueError('Failed to find file:' + f)
    # Batch tensors in (label, image) form.
    data_tuple = read_cifar10(filenames)
    return data_tuple
if __name__ == "__main__":
    # Map numeric labels to human-readable class names.
    label_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    name_list = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']
    label_map = dict(zip(label_list, name_list))

    with tf.compat.v1.Session() as sess:
        batch_data = get_data_from_files('cifar10_dir/cifar-10-batches-bin')
        # Older versions used a filename_queue and needed
        # tf.train.start_queue_runners(sess=sess) to start filling it;
        # 1.14.0 with tf.data does not.
        sess.run(tf.compat.v1.global_variables_initializer())

        # Create the output folder for the exported images.
        if not os.path.exists('cifar10_dir/raw'):
            os.mkdir('cifar10_dir/raw')

        # Save 30 images named index-typename.jpg, e.g. 1-frog.jpg.
        for i in range(30):
            # One batch of BATCH_SIZE records: (labels, images).
            batch_data_tuple = sess.run(batch_data)
            # Prints (128, 1)
            print(batch_data_tuple[0].shape)
            # Prints (128, 32, 32, 3)
            print(batch_data_tuple[1].shape)
            # Save the first image of each batch as a sample.
            Image.fromarray(batch_data_tuple[1][0]).save("cifar10_dir/raw/{index}-{type}.jpg".format(
                index=i, type=label_map[batch_data_tuple[0][0][0]]))