At every training step, the moving averages are updated right after train_op runs (via a control dependency).
# Three-layer conv net; assumes input images of shape (batchsize, 128, 128, 3).
def simplenet(inputs, y, class_num, istrain=True):
    """Build the graph: three conv/BN/relu/maxpool stages plus one FC classifier.

    Args:
        inputs: image batch; assumed (batchsize, 128, 128, 3) -- TODO confirm
            against the feeding code.
        y: ground-truth labels, forwarded to the external `loss` helper.
        class_num: number of output classes.
        istrain: forwarded to `batch_norm`; presumably switches between batch
            statistics (train) and moving statistics (inference) -- verify in
            the batch_norm helper.

    Returns:
        (logits, train_op, cross_loss, maintain_averages_op, shadow_W_conv1)
        where shadow_W_conv1 is the EMA shadow value of the first conv kernel.
    """
    with tf.variable_scope('conv1'):
        W_conv1 = weight_variable([5, 5, 3, 32])
        b_conv1 = bias_variable([32])  # one bias per output channel
        # BN scale/offset. NOTE(review): shape [1] is a single scalar shared
        # by all 32 channels; per-channel shape [32] is the usual choice.
        # Kept as-is because batch_norm's expected shapes are not visible here.
        gamma_conv1 = tf.Variable(tf.constant(1.0, shape=[1]))
        theta_conv1 = tf.Variable(tf.constant(0.0, shape=[1]))
        # SAME padding keeps the spatial size equal to the input (128x128).
        net = conv2d(inputs, W_conv1) + b_conv1
        net, _, _ = batch_norm(net, gamma_conv1, theta_conv1, istrain)
        net = tf.nn.relu(net)
        net = max_pool_2x2(net)  # -> (batchsize, 64, 64, 32)
    with tf.variable_scope('conv2'):
        W_conv2 = weight_variable([3, 3, 32, 64])
        b_conv2 = bias_variable([64])
        gamma_conv2 = tf.Variable(tf.constant(1.0, shape=[1]))
        theta_conv2 = tf.Variable(tf.constant(0.0, shape=[1]))
        net = conv2d(net, W_conv2) + b_conv2
        net, _, _ = batch_norm(net, gamma_conv2, theta_conv2, istrain)
        net = tf.nn.relu(net)
        net = max_pool_2x2(net)  # -> (batchsize, 32, 32, 64)
    with tf.variable_scope('conv3'):
        W_conv3 = weight_variable([3, 3, 64, 128])
        b_conv3 = bias_variable([128])
        gamma_conv3 = tf.Variable(tf.constant(1.0, shape=[1]))
        theta_conv3 = tf.Variable(tf.constant(0.0, shape=[1]))
        net = conv2d(net, W_conv3) + b_conv3
        net, _, _ = batch_norm(net, gamma_conv3, theta_conv3, istrain)
        net = tf.nn.relu(net)
        net = max_pool_2x2(net)  # -> (batchsize, 16, 16, 128)
    # -1 infers the batch dimension, so the graph no longer depends on the
    # global `batchsize` (the original hard-coded it here).
    net_flat = tf.reshape(net, [-1, 16 * 16 * 128])
    with tf.variable_scope('fc'):
        W_fc1 = weight_variable([16 * 16 * 128, class_num])
        b_fc1 = bias_variable([class_num])
        # BUG FIX: the original wrapped this in tf.nn.relu, which clips all
        # negative logits to 0 and cripples softmax cross-entropy training.
        # Logits must stay unscaled/unactivated.
        logits = tf.matmul(net_flat, W_fc1) + b_fc1  # (batchsize, class_num)
    cross_loss = loss(y, logits)
    train_op = optimizer1(cross_loss, lr)
    # Maintain exponential moving averages of the trainable variables; the
    # original ema.apply() relied on the same trainable-variables default,
    # made explicit here. The control dependency ties the EMA update to the
    # optimizer step, so running maintain_averages_op also trains.
    ema = tf.train.ExponentialMovingAverage(decay=0.9999)
    with tf.control_dependencies([train_op]):
        maintain_averages_op = ema.apply(tf.trainable_variables())
    shadow_W_conv1 = ema.average(W_conv1)
    return logits, train_op, cross_loss, maintain_averages_op, shadow_W_conv1
# Build the graph once, then run the training loop.
# Renamed `loss` -> `loss_value`: the original rebinding shadowed the `loss()`
# helper function used inside simplenet.
logits, train_op, loss_value, maintain_averages_op, shadow_W_conv1 = simplenet(x, y, class_num)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    steps_per_epoch = len(img_data) // batchsize
    steps = epochs * steps_per_epoch
    for step in range(steps):
        # BUG FIX: the original sliced inputs[step*batchsize:...] with a step
        # counter spanning ALL epochs, so every epoch after the first sliced
        # past the end of the data and fed empty batches. Wrap per epoch so
        # each epoch replays the same batch schedule as the first.
        start = (step % steps_per_epoch) * batchsize
        batch_inputs = inputs[start:start + batchsize]
        batch_labels = true_labels[start:start + batchsize]
        # Fetching maintain_averages_op also runs train_op via its control
        # dependency, so one optimizer step happens per iteration.
        ls, _, s = sess.run([loss_value, maintain_averages_op, shadow_W_conv1],
                            feed_dict={x: batch_inputs, y: batch_labels})
        if step % 100 == 0:
            print(step,' step: ',' loss is ', ls, s)  # ls: loss; s: EMA shadow of W_conv1
Sample run output: