接着上一篇文章的例子,将简易神经网络的预测结果可视化显示:
import matplotlib.pyplot as plt
# 迭代 1000 次学习,sess.run optimizer
with tf.Session() as sess:
sess.run(init)
fig = plt.figure() # 生成一个图框
ax = fig.add_subplot(1, 1, 1) # 将图框分成1行1列,在第1个图框中画图,参数形式也可以是(111),
# 同理,参数349的意思是:将画布分割成3行4列,图像画在从左到右从上到下的第9块。
ax.scatter(x_data, y_data) # 根据输入数据绘制散点图
plt.ion() #
plt.show() # 打印散点图
for i in range(1000):
sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
if i % 50 == 0:
# training train_step 和 loss 都是由 placeholder 定义的运算,所以这里要用 feed 传入参数
# print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
try:
ax.lines.remove(lines[0]) # 先抹除再绘制下一条,不然会有密密麻麻的线。第一次绘制时是没有之前的线的,所以要try。
except Exception:
pass
prediction_value = sess.run(prediction, feed_dict={xs: x_data}) # 我们需要预测的数据值,所以要Session一下
lines = ax.plot(x_data, prediction_value, 'r-', lw=5) # 绘制预测的数据值,
# 参数值分别表示:X轴数据、Y轴数据、以红色的线绘制、线宽为5
plt.pause(0.1) # 为了显示动态效果,暂停0.1秒再绘制下一条线
下载绘图工具包 matplotlib :pip install -i https://pypi.tuna.tsinghua.edu.cn/simple matplotlib
matplotlib API 链接。matplotlib.pyplot
提供类似MATLAB的绘图框架。
每一行代码的功能已经在程序中注释。
显示结果如下:
参考文档:
[1].https://www.jianshu.com/p/f7b4ac9159a1
[2].https://blog.csdn.net/weixin_40713373/article/details/80024583
感谢网友分享资料!如果文章出现侵权、知识错误、代码错误,请及时联系我,欢迎大家批评指正!