chapter2. CIFAR-10与ImageNet图像识别

一、Tensorflow 数据读取机制

1、tensorflow数据读包括两步:1)读取,2)计算。假设读取需花费0.1s,计算花费0.9s,则在读取数据时,GPU会有0.1s无事可做。这将大大降低GPU的运行效率。
下图为tensorflow数据读取 步骤 图示:
在这里插入图片描述
2、为提高GPU的运行效率,tensorflow采用“多线程”的方式,进行数据读取,线程一将数据从文件系统读入“内存队列”,线程二从“内存队列”中读取数据进行计算。采用这种方式,可以省去GPU等待I/O的时间,增加GPU效率。
下图为改进的tensorflow数据读取 步骤 图示:

3、在tensorflow中,为了方便管理,在上述设计的基础上,又加入了一个“文件名队列”。这是由于在“机器学习”中,会涉及到epoch的运用,加入“文件名队列”后,可以先将指定epoch 的 数据 读入“文件名队列”,并以“结束符”结尾,然后,在将这些数据读入到“内存队列”,提供给GPU计算。
下图为加入“文件名队列”的tensorflow 数据读取 流程:

4、上述过程可以简单用以下几个函数完成:

  1. 将指定epoch的数据读入“文件名队列”,该过程用如下函数实现:
tf.train.string_input_producer(filename,shuffle=False,num_epochs=5)
  1. 从“文件名队列”读取数据,该过程用如下函数实现:
reader = tf.WholeFileReader()
key,value = reader.read(filename_queue)
  1. 从“文件名队列”读取数据后,即可启动“填充队列线程”,将数据读入“内存队列”,填充队列线程 激活函数如下:
threads = tf.train.start_queue_runners(sess=session)
  1. 启动“填充线程”后,即可将数据导入“内存队列”,供GPU计算,获得数据。
    5、下面为tensorflow数据读取code:
#tensorflow数据读取
#导入tensorflow
import tensorflow as tf
#新建一个Session
with tf.Session as sess:
    filename=['A.jpg','B.jpg','C.jpg']
    #string_input_producer产生文件名队列
    filename_queue=tf.train.string_input_producer(filename,shuffle=False,num_epochs=5)
    #利用reader从文件名列队中读取数据
    reader =tf.WholeFileReader()
    key,value=reader.read(filename_queue)
    #对string_input_producer中变量epoch进行初始化???
    tf.local_variables_initializer().run()
    #启动填充线程,将读取出的数据,填入内存队列
    threads =tf.train.start_queue_runners(sess=sess)
    i=0
    while True:
        i +=1
        #从Session中读取数据
        image_data=sess.run(value) #image_data为二进制数据
        #将读取的数据存入文件夹read中
        with open('read/test_%d.jpg' % i,'wb') as f:
            f.write(image_data)

二,数据增强(data augmentation)

1、数据增强:在深度学习中,通常要求要有充足的数据量,数据量越大,model训练效果越好。所谓数据增强,是指通过“平移”,“旋转”,“缩放”,“裁剪”,“翻转”,“颜色变换”,“噪声扰动”等手段,人工增加数据样本量,从而获得更加充足的训练样本,使得model泛化能力更好,防止overfitting的发生。
使用“数据增强”的以前前提是,不能改变样本label,如在“手写数字识别”中,如果使用“翻转”进行“数据增强”,则将label=6变为label=9,因此,在该案例中,不能使用“翻转”进行数据增强。

#随机裁剪图片
distorted_image = tf.random_crop(reshaped_image,[height,width,3])
#随机翻转
distorted_image = tf.image.random_flip_left_right(distorted_image)
#随机改变亮度和对比度
distorted_image = tf.image.random_brightness(distorted_image,max_delta=63)
distorted_image = tf.image.random_contrast(distorted_image,lower=0.2,upper=0.8)

三、Tensorflow中读入数据大致有3种方法:

1、用占位符(placeholder)读入;
2、用队列的形式,建立文件到Tensor的映射;
3、用DataSet API读入数据;

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Sarah ฅʕ•̫͡•ʔฅ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值