tensorflow线性回归--拟合iris花瓣数据

思路:用线性回归拟合鸢尾花花瓣长度和宽度之间的关系:y = Ax + b,其中 y 时花瓣长度,x是花瓣宽度。
建议有一点 tensorflow 基础再往下看。
下面是代码具体讲解。

先放结果吧

在这里插入图片描述

代码讲解

#导入库
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
from sklearn import datasets

# sess 我的理解就是一个发电机,让运算开始起来。要运行就用 sess.run(...)
sess = tf.Session()
iris = datasets.load_iris()

# iris.data 是一个 150 行, 4 列的二维数组,如下面截图。4 列依次是萼片长度、萼片宽度、花瓣长度、花瓣宽度
y_vals  = np.array([x[2] for x in iris.data]) 	#花瓣长度,可以随便换
x_vals  = np.array([y[3] for y in iris.data])	#花瓣宽度

# batch_size 设定一次取 10 个数据来处理
batch_size = 10

# 这里写的虽然是 x_data, y_data,但它们其实只是一个容器一样的东西,
# x_data, y_data 可以装大小为 [batch_size, 1] ,类型为 tf.float64类型的数据
x_data = tf.placeholder(shape=[batch_size, 1], dtype=tf.float64)
y_data = tf.placeholder(shape=[batch_size, 1], dtype=tf.float64)

# 声明变量,变量就是要计算或优化的对象,此处就是 A 和 b。
# 下面要计算,x_data, y_data 是二维数组,所以也把 A, b 声明成二维数组,只有一个元素的二维数组
A = tf.Variable(tf.random_normal(shape=[1,1], dtype=tf.float64))
b = tf.Variable(tf.random_normal(shape=[1,1], dtype=tf.float64))

# 输入 y = A x + b
# 损失函数为误差平方的均值
# 用梯度下降优化,优化步长为 0.01,可以取其他
y_output = tf.add(tf.multiply(A, x_data), b)
loss = tf.reduce_mean(tf.square(y_output - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# loss_iter 是损失值变化,作图可以看损失值随迭代次数增加的变化情况
loss_iter = []
for i in range(201):
	# 产生batch_size 个范围在 0 ~ len(x_vals) 内的整数
	indexs = np.random.choice(len(x_vals), size=batch_size)
	# 我们每次随机选取 batch_size 个数据计算损失,进行优化
	x_rand = np.transpose([x_vals[indexs]])
	y_rand = np.transpose([y_vals[indexs]])
	
	loss_iter.append(sess.run(loss, feed_dict={x_data:x_rand, y_data:y_rand}))
	sess.run(optimizer, feed_dict={x_data:x_rand, y_data:y_rand})
	if  (i+1)%20 == 0:
		print('#\t'+str(i+1) + '\tA = ' + str(sess.run(A)) + '\tb = '+str(sess.run(b)))

# 取出最终的 A 、b 值
[[AA]] = sess.run(A)
[[bb]] = sess.run(b)
y_fit = AA * x_vals + bb

# 最好玩的画图
# 画回归图
plt.subplot(211)
plt.plot(x_vals, y_vals, 'go')
plt.plot(x_vals, y_fit, 'r-')
plt.xlabel('Petal length'); plt.ylabel('Petal width'); plt.title('Petal Length and Petal Width')

# 画损失图
plt.subplot(212)
plt.plot(range(len(loss_iter)), loss_iter, 'b-')
plt.xlabel('Generation'); plt.ylabel('Loss'); plt.title('The Loss per Generation')
plt.show()

鸢尾花部分数据

结果如上

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值