tensorflow 实现 正规方程求解线性回归

正规方程是用一个方程去求解线性回归问题,是最小二乘法的矩阵形式。在吴恩达的机器学习课程上也有提及。

这里写一个简单的用tensorflow实现的方法。参考《Tensorflow机器学习实战指南》

# -*- coding: utf-8 -*-
"""
Created on Sat Dec  9 20:16:39 2017

@author: www
"""

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np


#准备数据
x_vals = np.linspace(0, 10, 100)
y_vals = 2 * x_vals + np.random.normal(0, 1, 100)


sess = tf.Session()

x_vals_column = np.transpose(np.matrix(x_vals))
one_column = np.transpose(np.matrix(np.repeat(1, 100)))
A = np.column_stack((x_vals_column, one_column))
b = np.transpose(np.matrix(y_vals))

#转化为张量
A_tensor = tf.constant(A)
b_tensor = tf.constant(b)

#使用正规方程法
A_temp = tf.matmul(tf.transpose(A_tensor), A_tensor)
A_temp = tf.matrix_inverse(A_temp)
A_temp = tf.matmul(A_temp, tf.transpose(A_tensor))
solution = tf.matmul(A_temp, b_tensor)


solution_eval = sess.run(solution)

#得到系数,截距
slope = solution_eval[0][0]
intercept = solution_eval[1][0]

print('slope'+str(slope))
print('intercept'+str(intercept))


#画图显示
best_fit = []

for i in x_vals:
    best_fit.append(slope * i + intercept)
plt.plot(x_vals, y_vals, 'o', label ='Data')
plt.plot(x_vals, best_fit, 'r-', label ='Best fit line')
plt.legend(loc = 'upper left')
plt.show()










  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值