具体来说,我认为数据增强肯定是要把tfrecord转为图片,才能增强嘛。tfrecord格式读取更快,比图片读取。
当然有人说我就是喜欢拿原来图片直接数据增强,额,简单直接,能达到目的也行。
再一个上篇我实现了image-->tfrecord格式,所以为了验证对否,还是要转回来看看。
总的来说只有3步:
1.读取tfrecords,只是读取器变成了tf.TFRecordReader来读取tfrecord文件。
2.通过一个解析器tf.parse_single_example ,解析这个特殊的tfrecord格式文件。
3.然后用解码器 tf.decode_raw 解码。
========
效果如下:
直接上代码:
# -*- coding: utf-8 -*-
import tensorflow as tf
from PIL import Image
#写入将要保存图片路径,需要自己手动新建文件夹
swd = './tfrecord2pic'+'/'
#TFRecord文件路径,只能打开某一个具体的tfrecord,有多个那就改一下咯。
data_path = './traindata.tfrecords-003'
# 获取文件名列表
data_files = tf.gfile.Glob(data_path)
# 文件名列表生成器
filename_queue = tf.train.string_input_producer(data_files,shuffle=True)
reader = tf.TFRecordReader()
#上一篇说了,tfrecord格式数据度保存在值里面,即serialized_example,所以键不管
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
'img_width': tf.FixedLenFeature([], tf.int64),
'img_height': tf.FixedLenFeature([], tf.int64),
}) #取出包含image和label的feature对象
#tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.decode_raw(features['img_raw'], tf.uint8)
height = tf.cast(features['img_height'],tf.int32)
width = tf.cast(features['img_width'],tf.int32)
label = tf.cast(features['label'], tf.int32)
channel = 3
image = tf.reshape(image, [height,width,channel])
with tf.Session() as sess: #开始一个会话
init_op = tf.initialize_all_variables()
sess.run(init_op)
#启动多线程
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
#循环6次,所以转化了6张图片
for i in range(6):
single,l = sess.run([image,label])#在会话中取出image和label
img=Image.fromarray(single, 'RGB')#这里Image是之前提到的
#存下图片,格式是 第几张图片_label_所属类别标签号
img.save(swd+str(i)+'_''Label_'+str(l)+'.jpg')
coord.request_stop()
coord.join(threads)
====提示
1.红色字体是需要修改的部分,看注释改下吧。
2.每次读取都是从tfrecord第一张图片开始读取的,for i in range(xxx,xxxx):这里设置图片的编号,
同时还设置转化的图片数目。你如果想从第10张以后开始读取保存图片,简单:
你写个if语句,判断循环了多少次嘛。到了第10次才开始保存即可。
或者你想跳过中间某些图片不处理,还是写个if语句,count在那个范围之内你再读取嘛。
3.比如你tfrecord有300张图片,你设置for i in range(xxx,xxxx)读取500张,
其实后面200张图片又从tfrecord头开始读取重复图片了,不会保错。