思路:用线性回归拟合鸢尾花花瓣长度和宽度之间的关系:y = Ax + b
,其中 y 时花瓣长度,x是花瓣宽度。
建议有一点 tensorflow 基础再往下看。
下面是代码具体讲解。
先放结果吧
代码讲解
#导入库
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
# sess 我的理解就是一个发电机,让运算开始起来。要运行就用 sess.run(...)
sess = tf.Session()
iris = datasets.load_iris()
# iris.data 是一个 150 行, 4 列的二维数组,如下面截图。4 列依次是萼片长度、萼片宽度、花瓣长度、花瓣宽度
y_vals = np.array([x[2] for x in iris.data]) #花瓣长度,可以随便换
x_vals = np.array([y[3] for y in iris.data]) #花瓣宽度
# batch_size 设定一次取 10 个数据来处理
batch_size = 10
# 这里写的虽然是 x_data, y_data,但它们其实只是一个容器一样的东西,
# x_data, y_data 可以装大小为 [batch_size, 1] ,类型为 tf.float64类型的数据
x_data = tf.placeholder(shape=[batch_size, 1], dtype=tf.float64)
y_data = tf.placeholder(shape=[batch_size, 1], dtype=tf.float64)
# 声明变量,变量就是要计算或优化的对象,此处就是 A 和 b。
# 下面要计算,x_data, y_data 是二维数组,所以也把 A, b 声明成二维数组,只有一个元素的二维数组
A = tf.Variable(tf.random_normal(shape=[1,1], dtype=tf.float64))
b = tf.Variable(tf.random_normal(shape=[1,1], dtype=tf.float64))
# 输入 y = A x + b
# 损失函数为误差平方的均值
# 用梯度下降优化,优化步长为 0.01,可以取其他
y_output = tf.add(tf.multiply(A, x_data), b)
loss = tf.reduce_mean(tf.square(y_output - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)
# loss_iter 是损失值变化,作图可以看损失值随迭代次数增加的变化情况
loss_iter = []
for i in range(201):
# 产生batch_size 个范围在 0 ~ len(x_vals) 内的整数
indexs = np.random.choice(len(x_vals), size=batch_size)
# 我们每次随机选取 batch_size 个数据计算损失,进行优化
x_rand = np.transpose([x_vals[indexs]])
y_rand = np.transpose([y_vals[indexs]])
loss_iter.append(sess.run(loss, feed_dict={x_data:x_rand, y_data:y_rand}))
sess.run(optimizer, feed_dict={x_data:x_rand, y_data:y_rand})
if (i+1)%20 == 0:
print('#\t'+str(i+1) + '\tA = ' + str(sess.run(A)) + '\tb = '+str(sess.run(b)))
# 取出最终的 A 、b 值
[[AA]] = sess.run(A)
[[bb]] = sess.run(b)
y_fit = AA * x_vals + bb
# 最好玩的画图
# 画回归图
plt.subplot(211)
plt.plot(x_vals, y_vals, 'go')
plt.plot(x_vals, y_fit, 'r-')
plt.xlabel('Petal length'); plt.ylabel('Petal width'); plt.title('Petal Length and Petal Width')
# 画损失图
plt.subplot(212)
plt.plot(range(len(loss_iter)), loss_iter, 'b-')
plt.xlabel('Generation'); plt.ylabel('Loss'); plt.title('The Loss per Generation')
plt.show()