Tensorflow实现线性回归
之前有一篇文章已经写过如何使用tensorflow实现linear regerssion,使用的是官方提供的例程,这次我在原来基础上改写了一些问题,并加入了tensorboard可视化的一些小应用。
首先是样例数据人工生成,这里为了表达直观,我就以一维的数据为例。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
N = 200 # 样本点数目
x = np.linspace(-1, 1, N)
y = 2.0*x + np.random.standard_normal(x.shape)*0.3+0.5 # 生成线性数据
x = x.reshape([N, 1]) # 转换一下格式,准备feed进placeholder
y = y.reshape([N, 1])
plt.scatter(x, y)
plt.plot(x, 2*x+0.5)
plt.show()
然后就是建立计算图的过程,线性回归和之前聚类算法k-means不太一样,这类算法更适合使用tf来做,因为流程大致相同。
- 声明输入变量,数据规模小可以用placeholder占位,如果数据规模较大就需要用其余数据读取方式来驱动计算了。
- 声明运算中要是用的参数遍历。
- 定义运算节点op。
- 定义误差函数loss。
调用优化器函数优化loss。
按照以上流程,我们开始写出计算图。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# 建图
inputx = tf.placeholder(dtype=tf.float32, shape=[None, 1])
groundY = tf.placeholder(dtype=tf.float32, shape=[None, 1])
W = tf.Variable(tf.random_normal([1, 1], stddev=0.01))
b = tf.Variable(tf.random_normal([1], stddev=0.01))
pred = tf.matmul(inputx, W)+b
loss = tf.reduce_sum(tf.pow(pred-groundY, 2))
# 优化目标函数
train = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
# 加入监控点
tf.summary.scalar("loss", loss)
merged = tf.summary.merge_all()
# 初始化所有变量
init = tf.global_variables_initializer()
with tf.Session() as sess:
# 定义日志文件
writer = tf.summary.FileWriter("./log/", sess.graph)
sess.run(init)
for i in range(20):
sess.run(train,feed_dict={inputx:x, groundY:y})
predArr, lossArr = sess.run([pred, loss], feed_dict={inputx:x, groundY:y})
print(lossArr)
summary_str = sess.run(merged, feed_dict={inputx:x, groundY:y})
writer.add_summary(summary_str, i) # 向日志文件写入监控点数据
# 作图观察
WArr, bArr = sess.run([W, b])
print(WArr, bArr)
plt.scatter(x, y)
plt.plot(x, WArr * x + bArr)
plt.show()
可以看到,线性回归虽然简单,但是五脏俱全,这个例子基本涵盖了所有过程。每一步迭代过程都作图显示,这样可以清楚地看到回归直线的收敛情况。
另外就是tensorboard的使用,其基本原理是,程序运行在发生的事件可以在图中被定义出来,这些事件会被写入一个日志文件中,tensorboard会读取这个文件,并在本地服务器上显示出来。归结起来是如下几点。
- tf.summary.scalar操作是定义在哪个变量上设立check point,除此之外还有其余的summary类型。
- tf.summary.merge_all操作是汇集所有check point。
- writer = tf.summary.FileWriter(“./log/”, sess.graph)操作是定义输出的日志文件的路径,一般要在session中定义。
- summary_str = sess.run(merged, feed_dict={inputx:x, groundY:y}),运行汇集操作,序列化后的形成字符串。
- writer.add_summary(summary_str, i) 把序列化后的字符串写入log。
- 然后运行tensorboard –logdir=/path/to/log-directory打开本地服务器,在浏览器(最好是chrome)输入127.0.0.1:6006即可看到tensorboard。后面的文件路径就是第三步log文件所在路径,window平台下这里有一个bug,路径默认盘符一律是c盘,无法更改,即使你的log定义在了其他位置。