cifar10数据分类tensorflow官方源码解读

#一、基础知识
1、单词积累:
(1)dequeue:出列 取数据
(2)enqueue:入列 存数据
(3)queue:队列
2、白化处理参考这里:白化的目的是去除输入数据的冗余信息。假设训练数据是图像,由于图像中相邻像素之间具有很强的相关性,所以用于训练时输入是冗余的;白化的目的就是降低输入的冗余性。
输入数据集X,经过白化处理后,新的数据X’满足两个性质:
(1)特征之间相关性较低;
(2)所有特征具有相同的方差。
其实我们之前学的PCA算法中,可能PCA给我们的印象是一般用于降维操作。然而其实PCA如果不降维,而是仅仅使用PCA求出特征向量,然后把数据X映射到新的特征空间,这样的一个映射过程,其实就是满足了我们白化的第一个性质:除去特征之间的相关性。因此白化算法的实现过程,第一步操作就是PCA,求出新特征空间中X的新坐标,然后再对新的坐标进行方差归一化操作。
3、可以在Images页的列表中查看所有可用的数据增强的变换,对于每个原始图还附带了一个image_summary,以便于在TensorBoard中查看。这对于检查输入图像是否正确十分有用。
4、局部响应归一化LRN(Local Response Normalization):对局部神经元的活动创建竞争机制,使得其中响应比较大的值变得相对更大,并抑制其他反馈较小的神经元,增强了模型的泛化能力。
5、TensorFlow Datasets(GitHub),它将公共研究数据集公开为tf.data.Datasets和NumPy数组。 它完成了获取源数据并将其准备为磁盘上的通用格式的所有工作,并使用tf.data API构建高性能输入管道,这些管道准备好TensorFlow 2.0并可与tf.keras模型一起使用。
6、前后均带有双下划线__的命名
一般用于特殊方法的命名,用来实现对象的一些行为或者功能,比如__new__()方法用来创建实例,init()方法用来初始化对象;x + y操作被映射为方法x.add(y),序列或者字典的索引操作x[k]映射为x.getitem(k),len()、str()分别被内置函数len()、str()调用等等。
#二、源码解读
##1、cifar10_data.py
主要有2个函数:read_cifar10(file_queue)和inputs(data_dir, batch_size, distorted)
(1)读取源二进制文件中的数据,转化为训练标准数据
(2)数据增强,进一步转化为训练数据集和测试数据集

import os
import tensorflow as tf

num_classes = 10
num_examples_pre_epoch_for_train = 50000
num_examples_pre_epoch_for_eval = 10000

class CIFAR10Record(object):
    pass

def read_cifar10(file_queue):
    result = CIFAR10Record()
    label_bytes = 1
    result.height = 32
    result.width = 32
    result.depth = 3
    image_bytes = result.height * result.width * result.depth
    record_bytes = label_bytes + image_bytes
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)# 读取固定长度
    result.key, value = reader.read(file_queue)# key是文件名,value是图像和标签
    record_bytes = tf.decode_raw(value,tf.uint8)# uint8是8位无符号整型
    result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)# 将label转化为int32数据类型
    depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]),\
                             [result.depth, result.height, result.width])# 将image由一维字符串转化为,depth,height,width列表格式
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])# 转化为 height, width,depth的格式
    return result
#训练数据增强,测试不增强
def inputs(data_dir, batch_size, distorted):
    filenames = [os.path.join(data_dir, "data_batch_%d" % i) for i in range(1, 6)]# 获得文件路径
    file_queue = tf.train.string_input_producer(filenames)# 将文件分16进程放入队列
    read_input = read_cifar10(file_queue)# 读取队列
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)# 将image转化为float32格式
    num_examples_per_epoch = num_examples_pre_epoch_for_train
    if distorted != None:# 数据增强
        cropped_image = tf.random_crop(reshaped_image, [24, 24, 3])
        filpped_image = tf.image.random_flip_left_right(cropped_image)
        adjusted_brightness = tf.image.random_brightness(filpped_image, max_delta=0.8)
        adjusted_contrast = tf.image.random_contrast(adjusted_brightness, lower=0.2, upper=1.8)
        float_image = tf.image.per_image_standardization(adjusted_contrast) #标准化, 非归一化

        float_image.set_shape([24, 24, 3])# 设定图片大小
        read_input.label.set_shape([1])# 设定标签大小
        min_queue_examples = int(num_examples_pre_epoch_for_eval * 0.4)# 队列最少数据量
        print('Filling queue with %d CIFAR images before starting to train. This will take a few minutes.'\
                % min_queue_examples)

        images_train, labels_train = tf.train.shuffle_batch([float_image, read_input.label], batch_size=batch_size,\
                                                            num_threads=16, capacity=min_queue_examples + 3 * batch_size,\
                                                            min_after_dequeue = min_queue_examples)#随机选择选练数据
        return images_train, tf.reshape(labels_train, [batch_size])
    else:
        resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, 24, 24)
        float_image = tf.image.per_image_standardization(resized_image)
        float_image.set_shape([24, 24, 3])
        read_input.label.set_shape([1])
        min_queue_examples = int(num_examples_per_epoch * 0.4)
        images_test, labels_test = tf.train.batch([float_image, read_input.label], batch_size=batch_size, num_threads=16,\
                                                  capacity=min_queue_examples + 3*batch_size)
        return images_test, tf.reshape(labels_test, [batch_size])

##2、CNN_Cifar10.py
主要功能:
(1)获得训练测试数据;
(2)输入输出限流规范化(两个接口)
(3)构建主干网络
(4)损失函数与训练方法(训练方案)
(5)建立会话

import tensorflow as tf
import numpy as np
import time
import math
import cifar10_data

max_steps = 4000
batch_size = 100
num_examples_for_eval = 10000
data_dir = "/home/promise/Downloads/cifar-10-python/cifar-10-batches-py"

def variable_with_weight_loss(shape, stddev, w1):
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))# 权重初始化
    if w1 is not None:
        weights_loss = tf.multiply(tf.nn.l2_loss(var), w1, name = "weights_loss")# l2权重正则化
        tf.add_to_collection("losses", weights_loss)# 正则化加入损失函数add_to_collection()
    return var

images_train, labels_train = cifar10_data.inputs(data_dir=data_dir, batch_size=batch_size, distorted = True)#数据增强
images_test, labels_test = cifar10_data.inputs(data_dir=data_dir, batch_size=batch_size, distorted=None)#数据不增强

#模型进出口限流规范化
x= tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
y_ = tf.placeholder(tf.int32, [batch_size])

kernel1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2, w1=0.0)#卷积核初始化,存在权重正则化
conv1 = tf.nn.conv2d(x, kernel1, [1, 1, 1, 1], padding="SAME")#卷积
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))#偏置项初始化,初始值为0
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, bias1))# 激活
pool1 = tf.nn.max_pool(relu1, ksize=[1, 3, 3, 1], strides=[1,2,2,1], padding="SAME")# 最大池化,有重复滑动可提高模型性能
kernel2 = variable_with_weight_loss(shape=[5,5,64,64], stddev=5e-2, w1=0.0)
conv2 = tf.nn.conv2d(pool1, kernel2, [1,1,1,1], padding="SAME")
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, bias2))
pool2 = tf.nn.max_pool(relu2, ksize=[1,3,3,1], strides=[1,2,2,1], padding="SAME")

reshape = tf.reshape(pool2, [batch_size, -1])#将输出转化为(batch * features)格式
dim = reshape.get_shape()[1].value #列为特征维度

weight1 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)#全链接层初始化需要计算权重正则化
fc_bias1 = tf.Variable(tf.constant(0.1, shape=[384]))
fc_1 = tf.nn.relu(tf.matmul(reshape, weight1) + fc_bias1)
weight2 = variable_with_weight_loss(shape=[384,192], stddev=0.04, w1=0.004)
fc_bias2 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(fc_1, weight2) + fc_bias2)
weight3 = variable_with_weight_loss(shape=[192, 10], stddev=1/192.0, w1=0.0)
fc_bias3 = tf.Variable(tf.constant(0.0, shape=[10]))
result = tf.add(tf.matmul(local4, weight3), fc_bias3)

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=result, labels=tf.cast(y_, tf.int64))#分类问题采用稀疏化交叉熵损失
weights_with_l2_loss = tf.add_n(tf.get_collection("losses"))#add_n(tf.get_collection())与tf.add_to_collection()函数绑定使用
loss = tf.reduce_mean(cross_entropy) + weights_with_l2_loss# 总损失
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss) # Adam梯度下降优化算法
top_k_op = tf.nn.in_top_k(result, y_, 1)

sess =tf.InteractiveSession()
tf.global_variables_initializer().run()
tf.train.start_queue_runners()# 采用队列时要在建立的会话中先激活
for step in range(max_steps):
    start_time = time.time()# 计时开始在第一次循环迭代开始
    image_batch, label_batch = sess.run([images_train, labels_train])# 加载训练数据
    _, loss_value = sess.run([train_op, loss], feed_dict={x:image_batch, y_: label_batch})
    duration = time.time() - start_time #记录一批次数据训练时间
    if step % 100 == 0:
        examples_per_sec = batch_size/duration
        sec_per_batch = float(duration)
        print("step %d, loss = %.2f(%.1f examples/sec;) %.3f sec/batch" % (step, loss_value, examples_per_sec, sec_per_batch))

num_batch = int(math.ceil(num_examples_for_eval/batch_size))#math.ceil()向上取整
true_count = 0
total_sample_count = num_batch * batch_size
for i in range(num_batch):
    image_batch, label_batch = sess.run([images_test, labels_test])
    predictions = sess.run([top_k_op], feed_dict={x:image_batch, y_:label_batch})
    true_count += np.sum(predictions)

print("accuracy = %.3f%%" % ((true_count/total_sample_count)*100))

#三、存留疑问
1、卷积核、偏置项、全连接曾的权重初始化依据是什么?

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值