1 importcifar10, cifar10_input2 importtensorflow as tf3 importnumpy as np4 importtime5 importmath6
7 max_steps = 3000
8 batch_size = 128
9 data_dir = '/tmp/cifar10_data/cifar-10-batches-bin'
10
11
12 defvariable_with_weight_loss(shape, stddev, w1):13 '''定义初始化weight函数,使用tf.truncated_normal截断的正态分布,但加上L2的loss,相当于做了一个L2的正则化处理'''
14 var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))15 '''w1:控制L2 loss的大小,tf.nn.l2_loss函数计算weight的L2 loss'''
16 if wl is notNone:17 weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')18 '''tf.add_to_collection:把weight losses统一存到一个collection,名为losses'''
19 tf.add_to_collection('losses', weight_loss)20
21 returnvar22
23
24 #使用cifar10类下载数据集并解压展开到默认位置
25 cifar10.maybe_download_and_extract()26
27 '''distored_inputs函数产生训练需要使用的数据,包括特征和其对应的label,28 返回已经封装好的tensor,每次执行都会生成一个batch_size的数量的样本'''
29 images_train, labels_train = cifar10_input.distored_inputs(data_dir=data_dir,30 batch_size=batch_size)31
32 images_test, l