tensorflow实例(8.1)--公式法计算简单线性回归(Simple Regression Analysis)

简单回归分析(Simple Regression Analysis)定义是确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。
简单的讲就是如下图,一堆的散点图,找出一条 y=ax平方 + b 的直线最能表示这些散点图,
关于简单回归分析的简要介绍可以参考  机器学习(8)--简单线性回归(Simple Regression Analysis) 那是仅使用numpy实现算法
同样用tensorflow实现我用了两种方法
1、这篇主要用公式实现,
2、用了梯度下降法实现,因为梯度下降是一个动态过程,所以在matplotlib.pyplot显有一个动态的变化过程,可参考
tensorflow实例(8.2)--梯度下降法计算简单线性回归(Simple Regression Analysis)

其实这篇文章比起前面的都简单多,主要的目的还是复习一下tensorflow,这段代码里我用了tensorflow的传入量显得复杂了一点

,其实不用传入量能更简单就能实现,有兴趣不妨自己动动手


import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
x_data = np.array([1,3,2,1,3])
y_data = np.array([14,24,18,17,27])

#定义传入量
x_pl=tf.placeholder(tf.float64)
y_pl=tf.placeholder(tf.float64)

#建立a,b的tensorflow的模型
a=tf.reduce_sum(tf.multiply(x_pl-tf.reduce_mean(x_pl),y_pl-tf.reduce_mean(y_pl))) / tf.reduce_sum(tf.pow(x_pl-tf.reduce_mean(x_pl),2))
b=tf.reduce_mean(y_pl)-tf.reduce_mean(x_pl)*a

#计算
sess=tf.Session()
sess.run(tf.global_variables_initializer())
a=sess.run(a,feed_dict={x_pl:x_data,y_pl:y_data})
b=sess.run(b,feed_dict={x_pl:x_data,y_pl:y_data})
sess.close()

#绘制plt
print(a,b)
x1=0
y1=x1*a+b
x2=4
y2=x2*a+b
plt.scatter(x_data,y_data)
plt.plot([x1,x2],[y1,y2],c='r')
plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值