from models.tutorials.image.cifar10 import cifar10,cifar10_input
import tensorflow as tf
import time
# -----------定义常量-------------------
max_steps = 3000
batch_size = 128
data_dir = 'cifar-10-batches-bin'
L2范式:对模型进行正则化,增加一个正则化项,正则化项是模型复杂度的单调递增函数,模型越复杂正则化项就越大。可以防止参数过多模型过于复杂导致的过拟合现象。L2范数为1/||w||**2。
tf.nn.l2_loss(t, name=None)
解释:这个函数的作用是利用 L2 范数来计算张量的误差值,但是没有开方并且只取 L2 范数的值的一半,具体如下:
output = sum(t ** 2) / 2
使用tf.multiply()将w1与L2范式相乘来控制正则化的大小。使用collection将loss值存储在loss集合中.
# ---------------定义初始化函数--------------
def variable_with_weight_loss(shape,stddev,w1):
var = tf.Variable(tf.truncated_normal(shape,stddev=stddev))
if w1 is not None:
weight_loss = tf.multiply(tf.nn.l2_loss(var),w1,name='weight_loss')
tf.add_to_collection('losses',weight_loss)
return var
# --------------生成训练数据和测试数据-----------------
image_train,label_train = cifar10_input.distorted_inputs(data_dir=data_dir,batch_size=batch_size)
image_test,lable_test = cifar10_input