Tensorflow 入门学习6 梯度下降算法

梯度下降法是一个一阶最优化算法,通常也称为最速下降法。要使用梯度下降法找到一个函数的局部极小值,必须向函数上当前点对于梯度(或者是近似梯度)的反方向的规定步长距离点进行迭代搜索。所以梯度下降法可以帮助我们求解某个函数的极小值或者最小值。对于n维问题就最优解,梯度下降法是最常用的方法之一。

参考来源:https://segmentfault.com/a/1190000011994447

其原理推导可以参考:
https://www.jianshu.com/p/c7e642877b0e

上一节演示的一个隐藏层的神经网络实现回归分析,使用的就是梯度下降法。这里再用程序可视化展示梯度下降的过程。

示例

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 学习率
lr = 0.1
real_params = [2.3, 4.5]  # 真正的参数

tf_X = tf.placeholder(tf.float32, [None, 1])
tf_y = tf.placeholder(tf.float32, [None, 1])
weight = tf.Variable(initial_value=[[5]], dtype=tf.float32)
bia = tf.Variable(initial_value=[[4]], dtype=tf.float32)
y = tf.matmul(tf_X, weight) + bia

# 损失函数计算,使用均方误差
loss = tf.losses.mean_squared_error(tf_y, y)
# 梯度下降
train_op = tf.train.GradientDescentOptimizer(lr).minimize(loss)

X_data = np.linspace(-1, 1, 200)[:, np.newaxis]
noise = np.random.normal(0, 0.1, X_data.shape)
# 生成噪声
y_data = X_data * real_params[0] + real_params[1] + noise

sess = tf.Session()
sess.run(tf.global_variables_initializer())

weights = []
biases = []
losses = []
# 训练
for step in range(400):
    w, b, cost, _ = sess.run([weight, bia, loss, train_op],
                             feed_dict={tf_X: X_data, tf_y: y_data})
    weights.append(w)
    biases.append(b)
    losses.append(cost)
result = sess.run(y, feed_dict={tf_X: X_data, tf_y: y_data})

# 画出拟合图像
plt.figure(1)
plt.scatter(X_data, y_data, color='r', alpha=0.5)
plt.plot(X_data, result, lw=3)

# 画梯度下降示例
fig = plt.figure(2)
ax_3d = Axes3D(fig)
w_3d, b_3d = np.meshgrid(np.linspace(-2, 7, 30), np.linspace(-2, 7, 30))
loss_3d = np.array(
    [np.mean(np.square((X_data * w_ + b_) - y_data))
     for w_, b_ in zip(w_3d.ravel(), b_3d.ravel())]).reshape(w_3d.shape)
ax_3d.plot_surface(w_3d, b_3d, loss_3d, cmap=plt.get_cmap('rainbow'))
weights = np.array(weights).ravel()
biases = np.array(biases).ravel()

# 描绘初始点
ax_3d.scatter(weights[0], biases[0], losses[0], s=30, color='r')
ax_3d.set_xlabel('w')
ax_3d.set_ylabel('b')
ax_3d.plot(weights, biases, losses, lw=3, c='r')
plt.show()

结果:
在这里插入图片描述

本文代码来自:
https://blog.csdn.net/winycg/article/details/78524685

代码

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

编程圈子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值