引入matplotlib来查看拟合效果
首先对随机生成的数据进行可视化,用于查看我们人为生成的数据的分布情况
# 引入可视化库
import matplotlib.pyplot as plt
# 这部分为上篇中生成数据的代码
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise
# 接下里调用可视化函数,用于查看现在的数据样式
# 获取Figure对象
fig = plt.figure()
# 每个Figure对象可以拥有多个子图(Axes),使用subplot函数获取对象
ax = fig.add_subplot(1, 1, 1)
# 使用散点图的形式查看数据
ax.scatter(x_data, y_data)
# 保证图像可以实时刷新,而不阻塞
plt.ion()
# 展示图像
plt.show()
以上代码的运行效果如下图所示:
接下来修改训练部分的代码,用于动态刷新图表
for i in range(1000):
sess.run(trainStep, feed_dict={xs: x_data, ys: y_data})
# 修改部分
if (i % 10 == 0):
# print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
# 尝试移除图像中原有的曲线
# 如果不移除,则多条取钱都会叠加到图表上
try:
ax.lines.remove(lines[0])
except Exception:
pass
# 获取预测数据
prediction_value = sess.run(prediction, feed_dict={xs: x_data})
# 对遇到的数据进行构图
lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
# 暂停0.1秒查看曲线的拟合情况
plt.pause(0.1)
运行结果如下图所示,应该是个动态图效果,可以自己运行代码尝试
更改之后的全部代码
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def addLayer(inputs, inSize, outSize, activationFunction = None):
Weights = tf.Variable(tf.random_normal([inSize, outSize]))
biases = tf.Variable(tf.zeros([1, outSize]) + 0.1)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if (activationFunction == None):
outputs = Wx_plus_b
else:
outputs = activationFunction(Wx_plus_b)
return outputs
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
l1 = addLayer(xs, 1, 10, activationFunction=tf.nn.relu)
prediction = addLayer(l1, 10, 1, activationFunction = None)
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
reduction_indices = [1]))
trainStep = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()
for i in range(1000):
sess.run(trainStep, feed_dict={xs: x_data, ys: y_data})
if (i % 10 == 0):
# print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
try:
ax.lines.remove(lines[0])
except Exception:
pass
prediction_value = sess.run(prediction, feed_dict={xs: x_data})
# plot the prediction
lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
plt.pause(0.1)