Tensorflow(四)- CNN_CIFAR(一)- cifar10_input

本文详细介绍了如何使用Tensorflow的cifar10_input.py文件读取和处理CIFAR-10数据集,包括数据读取的三种方式,重点讲解了distorted_input、read_cifar10和generate_image_and_label_batch函数,涉及数据增强、队列管理和数据批处理的过程。
摘要由CSDN通过智能技术生成

这篇开始,讲述自己对于tensorflow文档中,利用CNN建立CIFAR-10模型的理解,如有错误欢迎指正,也是互相学习。由于代码太长,所以分几篇来讲述。
第一篇是关于cifar10_input.py文件。

cifar10_input.py

这个文件主要用来进行数据读取以及输入数据处理。
Tensorflow一共有三种读取数据的方式:
第一种最简单的预加载数据,直接在graph中定义常量和变量来保存数据(仅适用于数据小的情况,神经网络自然这种方法不行)
第二种供给数据(feeding),那么这种方法前面出现过很多次,也就是在图中建立placeholder,然后在跑图的时候再对占位进行数据feed。(如果数据量过大,一次性读入所有数据,再分批次feed进图,也会占用太多的内存空间)
第三种从文件中读取数据,在Tensorflow的起始,用一个输入管线从文件中读取数据。
那么这个模型就是运用了第三种数据读取方式。
这里我们按照实际建立图的顺序来进行讲解,由cifar10_train.py文件我们知道先调用的是distorted_input函数。

distorted_input

函数输入为data_dir(训练集所在文件夹),以及batch_size。
函数输出为图像组成的4维tensor,以及labels组成的1维tensor。
下面为数据读取步骤:

  1. 生成文件名列表,也就是把需要输入的所有文件的文件名放到一个列表里。
  2. 将文件名列表输入到tf.train.string_input_producer()函数中,生成一个先入先出的文件名队列,同时将一个QueueRunner添加到整个图的QUEUE_RUNNER当中。(tf.train.QueueRunner本质上是tensorflow的一个类,用来完成队列的一系列入队操作)
    下面为API文档,具体请参考文档。
    https://www.tensorflow.org/api_docs/python/tf/train/string_input_producer
  3. 利用文件阅读器读取文件名队列中文件里的数据。接下来我们跟随代码转移到read_cifar10函数中。
  4. 从read_cifar10函数回来,这个时候read_input已经是刚刚输出的结构体了,它有一个样本的所有信息。然后对图像数据进行一系列预处理,这是data augmentation,包括随机裁取,随机左右翻转,随机亮度调节,随机对比度调节,图像归一化。
    所有关于图像操作的API文档
    https://www.tensorflow.org/api_guides/python/image
  5. 定义参数min_queue_examples,然后进入到generate_image_and_label_batch()函数中。
def distorted_inputs(data_dir, batch_size):
    # 生成文件名列表
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in range(1, 6)]
    for f in filenames:
        if not gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    # 生成文件名队列
    filename_queue 
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值