一、Tensorflow 数据读取机制
1、tensorflow数据读包括两步:1)读取,2)计算。假设读取需花费0.1s,计算花费0.9s,则在读取数据时,GPU会有0.1s无事可做。这将大大降低GPU的运行效率。
下图为tensorflow数据读取 步骤 图示:
2、为提高GPU的运行效率,tensorflow采用“多线程”的方式,进行数据读取,线程一将数据从文件系统读入“内存队列”,线程二从“内存队列”中读取数据进行计算。采用这种方式,可以省去GPU等待I/O的时间,增加GPU效率。
下图为改进的tensorflow数据读取 步骤 图示:
3、在tensorflow中,为了方便管理,在上述设计的基础上,又加入了一个“文件名队列”。这是由于在“机器学习”中,会涉及到epoch的运用,加入“文件名队列”后,可以先将指定epoch 的 数据 读入“文件名队列”,并以“结束符”结尾,然后,在将这些数据读入到“内存队列”,提供给GPU计算。
下图为加入“文件名队列”的tensorflow 数据读取 流程:
4、上述过程可以简单用以下几个函数完成:
- 将指定epoch的数据读入“文件名队列”,该过程用如下函数实现:
tf.train.string_input_producer(filename,shuffle=False,num_epochs=5)
- 从“文件名队列”读取数据,该过程用如下函数实现:
reader = tf.WholeFileReader()
key,value = reader.read(filename_queue)
- 从“文件名队列”读取数据后,即可启动“填充队列线程”,将数据读入“内存队列”,填充队列线程 激活函数如下:
threads = tf.train.start_queue_runners(sess=session)
- 启动“填充线程”后,即可将数据导入“内存队列”,供GPU计算,获得数据。
5、下面为tensorflow数据读取code:
#tensorflow数据读取
#导入tensorflow
import tensorflow as tf
#新建一个Session
with tf.Session as sess:
filename=['A.jpg','B.jpg','C.jpg']
#string_input_producer产生文件名队列
filename_queue=tf.train.string_input_producer(filename,shuffle=False,num_epochs=5)
#利用reader从文件名列队中读取数据
reader =tf.WholeFileReader()
key,value=reader.read(filename_queue)
#对string_input_producer中变量epoch进行初始化???
tf.local_variables_initializer().run()
#启动填充线程,将读取出的数据,填入内存队列
threads =tf.train.start_queue_runners(sess=sess)
i=0
while True:
i +=1
#从Session中读取数据
image_data=sess.run(value) #image_data为二进制数据
#将读取的数据存入文件夹read中
with open('read/test_%d.jpg' % i,'wb') as f:
f.write(image_data)
二,数据增强(data augmentation)
1、数据增强:在深度学习中,通常要求要有充足的数据量,数据量越大,model训练效果越好。所谓数据增强,是指通过“平移”,“旋转”,“缩放”,“裁剪”,“翻转”,“颜色变换”,“噪声扰动”等手段,人工增加数据样本量,从而获得更加充足的训练样本,使得model泛化能力更好,防止overfitting的发生。
使用“数据增强”的以前前提是,不能改变样本label,如在“手写数字识别”中,如果使用“翻转”进行“数据增强”,则将label=6变为label=9,因此,在该案例中,不能使用“翻转”进行数据增强。
#随机裁剪图片
distorted_image = tf.random_crop(reshaped_image,[height,width,3])
#随机翻转
distorted_image = tf.image.random_flip_left_right(distorted_image)
#随机改变亮度和对比度
distorted_image = tf.image.random_brightness(distorted_image,max_delta=63)
distorted_image = tf.image.random_contrast(distorted_image,lower=0.2,upper=0.8)
三、Tensorflow中读入数据大致有3种方法:
1、用占位符(placeholder)读入;
2、用队列的形式,建立文件到Tensor的映射;
3、用DataSet API读入数据;