学习笔记之思路整理

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/hensonwells/article/details/75127523

1.图片处理:(流程被分配在16个线程中处理)
图片会被统一裁剪到24x24像素大小,裁剪中央区域用于评估或随机裁剪用于训练;
图片会进行近似的白化处理,使得模型对图片的动态范围变化不敏感。
对图像进行随机的左右翻转;
随机变换图像的亮度;
随机变换图像的对比度;

训练方法与损失的定义:
训练一个可进行N维分类的网络的常用方法是使用多项式逻辑回归(softmax 回归),
Softmax 回归在网络的输出层上附加了一个softmax nonlinearity,
并且计算归一化的预测值和label的1-hot encoding的交叉熵。
在正则化过程中,对所有学习变量应用权重衰减损失(使用了L2范式,强调模型的参数的稀疏性),
求交叉熵损失和所有权重衰减项的和,loss()函数的返回值就是这个值

2.数据读取:
(1)读取:
读取文件队列名,用read_cifar10()来获取一个样本的信息结构体(大小、数据、标签),
使用tf.cast转换uint8成float32
(2)切割:read_cifar10(),该函数从二进制数据中读取数据并规整,
每条样本都是先标签后数据,CIFAR10是一个字节标签,
CIFAR100是2字节,使用切片函数tf.slice()
(3)处理原始图片:初步获取数据后就需要变形成tensor了, tf.random_crop(reshaped_image,[height,
width,3]) 1D变换成3D,对图像进行了很多随机扭曲处理…通过tf.train.shuffle_batch中设定队列大小、缓冲区大小,
直接就保证整理好一个数据集合的队列

3.建立训练网络:
(1)参数设置函数
_variable_with_weight_decay(name,shape,stddev,wd)
对应功能:输入名称、形状、偏差和均值 就定义一个参数tensor
(2)生成数据
先设置常量参数,再由tf.nn.l2_loss(var)增加L2范式稀疏化
L2范式定义为:output = sum(t ** 2) / 2,然后乘以一个衰减系数wd做为一个训练指标:
这个值应该尽量小,以保证稀疏性
用tf.add_to_collection(‘losses’,weight_decay)把所有的系数作为以losses为标签进行收集
用summary用于查看输出的稀疏性:tf.scalar_summary(tensor_name+’/sparsity’,
tf.nn.zero_fraction(x)),统计0的比例反应稀疏性。tf.histogram_summary(tensor_name+’/activations’,
x),输出数值的分布直接反应神经元的活跃性,如果全是很小的值说明不活跃。
全连接层的展开维度192、384这些数字与GPU的架构有关,全连接层的wd是0.004,略微强调了一下稀疏性

ps:多个GPU需要tf.get_variable()用于分享数据,而单个GPU只需要tf.Variable()

4.损失函数:
cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits,
labels,name=’cross_entropy_per_example’)

5.训练:
(1)学习率更新:首先是根据当前的训练步数、衰减速度、之前的学习速率确定新的学习速率
lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE……… staircase=True)
式子:decayed_learning_rate=learening_rate*decay_rata^(global_step/decay_steps)
如果staircase=True则取整数
(2)均值线(ExponentialMovingAverage)
(3)计算梯度及更新梯度compute_gradients,opt.apply_gradients反向传播
(4)summary和句柄

6.测试模型
(1)传入验证函数的参数:
eval_once(saver, saver是用读取moving_average的
summary_writer, summary_writer和summary_op是保存记录的
top_k_op, top_k_op传入了模型和验证模型
summary_op)
(2) 读取检查点:
ckpt=tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
从检查点恢复图和参数:
saver.restore(sess,ckpt.model_checkpoint_path)

阅读更多
想对作者说点什么?

博主推荐

换一批

没有更多推荐了,返回首页