tensorflow实例(8.2)--梯度下降法计算简单线性回归(Simple Regression Analysis)

简单回归分析(Simple Regression Analysis)定义是确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。
简单的讲就是如下图,一堆的散点图,找出一条 y=ax平方 + b 的直线最能表示这些散点图,
关于简单回归分析的简要介绍可以参考  机器学习(8)--简单线性回归(Simple Regression Analysis) 那是仅使用numpy实现算法
同样用tensorflow实现我用了两种方法
1、主要用公式实现,可参考 tensorflow实例(8.1)--公式法计算简单线性回归(Simple Regression Analysis)

2、本文采用了梯度下降法实现,因为梯度下降是一个动态过程,所以在matplotlib.pyplot显有一个动态的变化过程

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
x_data = np.array([1,3,2,1,3])
y_data = np.array([14,24,18,17,27])

#初始化不为零的a与b
a=tf.Variable(np.random.rand())
b=tf.Variable(np.random.rand())

#建立tensorflow模型
y=tf.multiply(a,x_data)+b
loss=tf.reduce_sum(tf.pow(y-y_data, 2))/(2*x_data.shape[0])
train=tf.train.GradientDescentOptimizer(0.1).minimize(loss)

#计算
sess=tf.Session()
sess.run(tf.global_variables_initializer())
abStep=[]
for i in range(600):
    sess.run(train)
    if i%50==0: 
        print("步骤:%d,  loss:%f, a:%f,  b:%f"%(i,sess.run(loss),sess.run(a),sess.run(b)))
        abStep.append((sess.run(a),sess.run(b)))
abStep.append((sess.run(a),sess.run(b)))
sess.close()

#绘图
plt.scatter(x_data,y_data)
plt.ion()
plt.show()
pltStep=None

for x in abStep:
    if pltStep!=None:pltStep[0].remove()
    x1=0
    y1=x1*x[0]+x[1]
    x2=4
    y2=x2*x[0]+x[1]
    pltStep=plt.plot([x1,x2],[y1,y2],c='r')
    plt.pause(0.5)
    


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值