首先是用来训练的会话:
#定义反向传播方法:不含正则化
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss_mse)
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 40000
for i in range(STEPS):
start = (i*BATCH_SIZE)%300
end = start + BATCH_SIZE
sess.run(train_step,feed_dict={x:X[start:end],y_:Y_[start:end]})
if i % 2000 == 0:
loss_mse_v = sess.run(loss_mse,feed_dict={x:X,y_:Y_})
print("After " + str(i)+ " steps, loss is: " + str(loss_mse_v))
无论是正则化也好,无正则化程序也好,训练步骤基本上都是这样,正则化程序无非是改改 train_step 罢了:
#定义反向传播方法:包含正则化
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss_total)
以及每2000次打印的程序:
loss_total_v = sess.run(loss_total,feed_dict={x:X,y_:Y_})
print("After " + str(i)+ " steps, loss is: " + str(loss_total_v))
通过训练,我们得到了目标网络参数。
之后我们要把一些数据放进去训练,然后把结果用来标注:
#xx在-3到3之间以步长为0.01,yy在-3到3之间以步长为0.01,生成二维网格坐标点
xx,yy = np.mgrid[-3:3:.01,-3:3:.01]
#将xx,yy拉直,并合并成一个2列的矩阵,得到一个网格坐标点的集合
grid = np.c_[xx.ravel(),yy.ravel()]
#将网格坐标点喂入神经网络,probs为输出
probs = sess.run(y,feed_dict={x:grid})
#probs的shape调整成xx的样子
probs = probs.reshape(xx.shape)
print("w1: "+ str(sess.run(w1)))
print("b1: "+ str(sess.run(b1)))
print("w2: " + str(sess.run(w2)))
print("b2: " + str(sess.run(b1)))
plt.scatter(X[:,0],X[:,1],c=np.squeeze(Y_c))
plt.contour(xx,yy,probs,levels=[.5])
plt.show()
这里有个地方比较难理解,我们详细说一下:
#xx在-3到3之间以步长为0.01,yy在-3到3之间以步长为0.01,生成二维网格坐标点
xx,yy = np.mgrid[-3:3:.01,-3:3:.01]
#将xx,yy拉直,并合并成一个2列的矩阵,得到一个网格坐标点的集合
grid = np.c_[xx.ravel(),yy.ravel()]
首先是第一行,生成二维网格坐标点,我们把xx,yy分别打印。得到如下结果:
[[-3. -3. -3. ... -3. -3. -3. ]
[-2.99 -2.99 -2.99 ... -2.99 -2.99 -2.99]
[-2.98 -2.98 -2.98 ... -2.98 -2.98 -2.98]
...
[ 2.97 2.97 2.97 ... 2.97 2.97 2.97]
[ 2.98 2.98 2.98 ... 2.98 2.98 2.98]
[ 2.99 2.99 2.99 ... 2.99 2.99 2.99]]
[[-3. -2.99 -2.98 ... 2.97 2.98 2.99]
[-3. -2.99 -2.98 ... 2.97 2.98 2.99]
[-3. -2.99 -2.98 ... 2.97 2.98 2.99]
...
[-3. -2.99 -2.98 ... 2.97 2.98 2.99]
[-3. -2.99 -2.98 ... 2.97 2.98 2.99]
[-3. -2.99 -2.98 ... 2.97 2.98 2.99]]
可以看到,上面的x每行数都是相同的,对于下面y每行逐步增加0.01。
我们用第二行程序把它们拉直,就变成了如下所示的数据:
[[-3. -3. ]
[-3. -2.99]
[-3. -2.98]
...
[ 2.99 2.97]
[ 2.99 2.98]
[ 2.99 2.99]]
就是-3到3上每隔0.01所有的网格点了。我们把这些网格点放进训练好的网络中,就能得到结果。
probs = sess.run(y,feed_dict={x:grid})
然后调整输出的probs的格式:
#probs的shape调整成xx的样子
probs = probs.reshape(xx.shape)
probs的形状要和xx,yy一样,因为待会我们画坐标点的时候就是要根据xx,yy作为坐标来画的。现在的probs里面的数据都是0到1之间的小数。我们判断依据是大于0.5的时候表示在圆内,小于0.5的时候不在圆内。
plt.contour(xx,yy,probs,levels=[.5])
这个函数主要对网格中每个点的值等于level的时候做出轮廓线,相当于把大于level和小于level的部分分隔开。
得到的最终结果如图所示:左边是不使用正则化的训练结果,右边是使用正则化以后的训练结果。