TensorFlow中cnn-cifar10样例代码详解

       TensorFlow是一个支持分布式的深度学习框架,在Google的推动下,它正在变得越来越普及。我最近学了TensorFlow教程上的一个例子,即采用CNN对cifar10数据集进行分类。在看源代码的时候,看完后有一种似懂非懂的感觉,又考虑到这个样例涵盖了tensorflow的大部分语法知识,包括QueueRunners机制、Tensorboard可视化和多GPU数据并行编程等。“纸上得来终觉浅,绝知此事要躬行”,于是我便照着自己的理解和记忆来重新编程实现一遍,实现过程中也遇到了一些问题,在这里把我对程序的详细的理解和遇到的问题记录下来,本着“知其然,知其所以然”的标准,程序中的注释更加侧重于对‘为什么这一行代码要这样写’的解读,以供以后参考(关于卷积神经网络相关的理论部分,网上有很多了,这里就不做太多介绍)。

一个标准的机器学习程序,应该包括数据输入、定义模型本身、模型训练和模型性能测试四大部分,可以分成四个.py文件。

(一)数据输入部分(input_dataset.py)

       从概念上来说,这部分主要是关于数据管道(data pipe)的构建,数据流向为“二进制文件->文件名队列->数据队列->读取出的data-batch”。数据块用于输入到深度学习网络中,进行信息的forward propagation,这部分在定义模型本身部分讨论。在定义整个数据管道的时候,会使用到TensorFlow的队列机制,这部分在之前的博文“TensorFlow读取二进制文件数据到队列”中进行了详细讲述。另外,读原数据文件的时候,要结合文件本身的格式,相信编写过c语言读取二进制文件的程序的朋友,对这应该再熟悉不过了,具体的代码如下,

# -*- coding: utf-8 -*-
import os
import tensorflow as tf
# 原图像的尺度为32*32,但根据常识,信息部分通常位于图像的中央,这里定义了以中心裁剪后图像的尺寸
fixed_height = 24
fixed_width = 24
# cifar10数据集的格式,训练样例集和测试样例集分别为50k和10k
train_samples_per_epoch = 50000
test_samples_per_epoch = 10000
data_dir='./cifar-10-batches-bin' # 定义数据集所在文件夹路径
batch_size=128 #定义每次参数更新时,所使用的batch的大小
 
def read_cifar10(filename_queue):
    # 定义一个空的类对象,类似于c语言里面的结构体定义
    class Image(object):
        pass
    image = Image()
    image.height=32
    image.width=32
    image.depth=3
    label_bytes = 1
    image_bytes = image.height*image.width*image.depth
    Bytes_to_read = label_bytes+image_bytes
    # 定义一个Reader,它每次能从文件中读取固定字节数
    reader = tf.FixedLengthRecordReader(record_bytes=Bytes_to_read) 
    # 返回从filename_queue中读取的(key, value)对,key和value都是字符串类型的tensor,并且当队列中的某一个文件读完成时,该文件名会dequeue
    image.key, value_str = reader.read(filename_queue) 
    # 解码操作可以看作读二进制文件,把字符串中的字节转换为数值向量,每一个数值占用一个字节,在[0, 255]区间内,因此out_type要取uint8类型
    value = tf.decode_raw(bytes=value_str, out_type=tf.uint8) 
    # 从一维tensor对象中截取一个slice,类似于从一维向量中筛选子向量,因为value中包含了label和feature,故要对向量类型tensor进行'parse'操作    
    image.label = tf.slice(input_=value, begin=[0], size=[label_bytes])# begin和size分别表示待截取片段的起点和长度
    data_mat = tf.slice(input_=value, begin=[label_bytes], size=[image_bytes])
    data_mat = tf.reshape(data_mat, (image.depth, image.height, image.width)) #这里的维度顺序,是依据cifar二进制文件的格式而定的
    transposed_value = tf.transpose(data_mat, perm=[1, 2, 0]) #对data_mat的维度进行重新排列,返回值的第i个维度对应着data_mat的第perm[i]维
    image.mat = transposed_value    
    return image    
    
def get_batch_samples(img_obj, min_samples_in_queue, batch_size, shuffle_flag):
'''
tf.train.shuffle_batch()函数用于随机地shuffling 队列中的tensors来创建batches(也即每次可以读取多个data文件中的样例构成一个batch)。这个函数向当前Graph中添加了下列对象:
*创建了一个shuffling queue,用于把‘tensors’中的tensors压入该队列;
*一个dequeue_many操作,用于根据队列中的数据创建一个batch;
*创建了一个QueueRunner对象,用于启动一个进程压数据到队列
capacity参数用于控制shuffling queue的最大长度;min_after_dequeue参数表示进行一次dequeue操作后队列中元素的最小数量,可以用于确保batch中
元素的随机性;num_threads参数用于指定多少个threads负责压tensors到队列;enqueue_many参数用于表征是否tensors中的每一个tensor都代表一个样例
tf.train.batch()与之类似,只不过顺序地出队列(也即每次只能从一个data文件中读取batch),少了随机性。
'''
    if shuffle_flag == False:
        image_batch, label_batch = tf.train.batch(tensors=img_obj, 
                                                  batch_size=batch_size, 
                                                  num_threads=4, 
                                                  capacity=min_samples_in_queue+3*batch_size)
    else:
        image_batch, label_batch = tf.train.shuffle_batch(tensors=img_obj, 
                                                          batch_size=batch_size, 
                                                          num_threads=4, 
                                                          min_after_dequeue=min_samples_in_queue,
                                                          capacity=min_samples_in_queue+3*batch_size)                                                    
    tf.image_summary('input_image', image_batch, max_images=6) #输出预处理后图像的summary缓存对象,用于在session中写入到事件文件中                                                    
    return image_batch, tf.reshape(label_batch, shape=[batch_size])     
                                       
def preprocess_input_data():
'''这部分程序用于对训练数据集进行‘数据增强’操作,通过增加训练集的大小来防止过拟合'''
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]
    #filenames =[os.path.join(data_dir, 'test_batch.bin')]
    for f in filenames: #检验训练数据集文件是否存在
        if not tf.gfile.Exists(f):
            raise ValueError('fail to find file:'+f)    
    filename_queue = t
  • 21
    点赞
  • 97
    收藏
    觉得还不错? 一键收藏
  • 90
    评论
评论 90
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值