原教程地址
这一节利用matplot库讲了可视化显示真实数据与预测数据
代码根据上一篇的内容进行添加和修改,为了方便看,我先只把修改的部分代码贴出来:
原代码及结果
1.首先引入matplot库
import matplotlib.pyplot as plt
2.然后显示真实数据
###可视化显示真实数据###
fig = plt.figure() #生成图片框#
ax = fig.add_subplot(1,1,1) #连续画图用axis,1,1,1是编号
ax.scatter(x_data,y_data)#显示真实数据
plt.show()#调用这句会将程序暂停
得到结果如图所示
3显示连续的预测数据
对原来代码的if部分进行修改,注释掉原来打印语句:
if i%50 == 0:
###to see the step improvement###
#print(sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
try:
ax.lines.remove(lines[0])#去除掉lines的第一个线段,防止产生很多线段导致看不清结果
except Exception:
pass
prediction_value = sess.run(prediction,feed_dict={
xs:x_data})
lines = ax.plot(x_data,prediction_value,'r-',lw=5)#用线显示,红色,宽度为5
plt