深度学习:Tensorflow实现线性回归梯度下降优化

版权声明:本文为博主原创文章,欢迎转载,请注明出处 https://blog.csdn.net/mouday/article/details/88096367

回顾

1、算法:线性回归

y=kx+by = kx + b

2、策略:均方误差

3、优化:梯度下降

步骤

1、准备好特征值和目标值

2、建立模型,随机初始化准备权重w和偏置b

y_predict=xw+by\_predict = xw + b

3、求损失函数,误差,均方误差
mse=(y1y1)2+(y2y2))2nmse = \frac{(y1-y1^-)^2 + (y2-y2^-))^2 }{n}

4、梯度下降去优化损失过程,指定学习率

矩阵相乘
(m行,n列) * (n 行, 1列) = (m行, 1列) + 偏置

TensorFlow运算API

# 矩阵运算
tf.matmul(x, w)

# 平方
tf.square(error)

# 均值
tf.reduce_mean(error)

梯度下降API

tf.train.GradientDescentOptimizer(learning_rate)

参数:
learning_rate 学习率
方法:
minimize(loss)

return 梯度下降op

tips:模型参数必须用变量定义

代码实现

# -*- coding: utf-8 -*-

"""
实现一个线性回归预测
"""

import tensorflow as tf

# 1、准备数据,x 特征值[100, 1] y 目标值 [100]
x = tf.random_normal((100, 1), mean=1.75, stddev=0.5, name="x_data")

# 矩阵相乘必须是二维的
y_true = tf.matmul(x, [[0.7]]) + 0.8

# 2、建立线性回归模型 1个特征,1个权重,1个偏置 y = xw + b
# 随机给一个权重和偏置的值,让他们去计算损失,然后在当前状态下优化
# 用变量定义才能优化
weight = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name="weight")
bias = tf.Variable(0.0, name="bias")

y_predict = tf.matmul(x, weight) + bias

# 3、建立损失函数,均方误差
loss = tf.reduce_mean(tf.square(y_true - y_predict))

# 4、梯度下降优化损失,学习率learn_rate 0,1,2,3,5,7,10
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# 定义一个初始化变量的op
init_op = tf.global_variables_initializer()


# 通过会话运行程序
with tf.Session() as sess:
    # 初始化变量
    sess.run(init_op)

    # 打印随机最先初始化的权重和偏置
    print("初始化的参数权重:{}, 偏置:{}".format(weight.eval(), bias.eval()))

    # 运行优化
    for i in range(200):
        sess.run(train_op)

        print("第 {} 次优化 参数权重:{}, 偏置:{}".format(i, weight.eval(), bias.eval()))

计算结果

初始化的参数权重:[[0.313286]], 偏置:0.0
第 0 次优化 参数权重:[[0.86829025]], 偏置:0.2980913817882538
第 1 次优化 参数权重:[[0.9330569]], 偏置:0.3393169939517975
第 2 次优化 参数权重:[[0.9391256]], 偏置:0.34980931878089905
第 3 次优化 参数权重:[[0.9435929]], 偏置:0.35885265469551086
...
第 197 次优化 参数权重:[[0.7236507]], 偏置:0.7558320760726929
第 198 次优化 参数权重:[[0.7237728]], 偏置:0.7565979361534119
第 199 次优化 参数权重:[[0.72236484]], 偏置:0.7566843032836914

计算越多,计算结果越接近真实值

没有更多推荐了,返回首页