1、引言
小屌丝:鱼哥, 现在还有些时间, 在讲一讲SGD呗
小鱼:… 我可以反驳吗?
小屌丝:可以啊, 反正520那天的情况,我也不会跟其他人说的
小鱼:你… 太好了。
小屌丝:是吧, 必须得。
小鱼:我们还是聊一聊SGD吧
小屌丝:先别急,慢慢聊
小鱼:你… 可以啊。
小屌丝:反正…嘿嘿
小鱼:忍一忍, 过了520,看你还能奈我何。
小屌丝:那这几天,我可不客气了。
2、随机梯度下降(SGD)
2.1 定义
随机梯度下降是一种优化算法,用于最小化目标函数,即减少模型预测和实际结果之间的差距。
它是梯度下降算法的一种变体,主要区别在于每次迭代只使用一个数据点来更新参数,而不是使用整个数据集。
这种方法可以显著加快计算速度,并使算法能够处理大规模数据集。
2.2 核心原理
SGD的核心原理是利用每个数据点的梯度(或者一小批数据点的平均梯度)来逐步调整模型参数,以求达到最小化损失函数的目的。
在每次迭代中,算法随机选择一个样本(或一小批样本),计算该样本的梯度,然后用这个梯度更新模型参数
2.3 实现方式
- 初始化:选择初始参数值。
- 迭代:直到满足停止条件(如达到最大迭代次数或梯度变化小于某个阈值)。
- 随机选择一个样本(或一小批样本)。
- 计算选中样本的梯度。
- 更新模型参数。
2.4 算法公式
假设
(
L
)
(L)
(L)是损失函数,
(
θ
)
(\theta)
(θ)是模型参数,
(
η
)
(\eta)
(η)是学习率,
(
g
)
(g)
(g)是计算得到的梯度,更新公式如下:
[
θ
=
θ
−
η
⋅
g
]
[ \theta = \theta - \eta \cdot g ]
[θ=θ−η⋅g]其中,
(
g
)
(g)
(g)是对单个样本(或一小批样本)计算得到的梯度。
2.5 代码示例
# -*- coding:utf-8 -*-
# @Time : 2024-05-16
# @Author : Carl_DJ
import numpy as np
# 生成模拟数据
np.random.seed(42)
x = 2 * np.random.rand(100, 1)
y = 4 + 3 * x + np.random.randn(100, 1)
# 初始化参数
w = np.random.randn(1, 1)
b = np.random.randn(1, 1)
# 学习率
learning_rate = 0.01
# 执行SGD的迭代次数
n_iterations = 1000
# 批量大小为1表示随机梯度下降
batch_size = 1
# 随机梯度下降
for iteration in range(n_iterations):
for i in range(0, x.shape[0], batch_size):
random_index = np.random.randint(0, x.shape[0])
xi = x[random_index:random_index+1]
yi = y[random_index:random_index+1]
# 计算梯度
gradients_w = -2 * xi.T.dot(yi - (w * xi + b))
gradients_b = -2 * np.sum(yi - (w * xi + b))
# 更新参数
w = w - learning_rate * gradients_w
b = b - learning_rate * gradients_b
print("最终参数 w:", w, "b:", b)
# 预测函数
def predict(x):
return w * x + b
# 使用模型进行预测
x_new = np.array([[0], [2]])
y_predict = predict(x_new)
print("预测结果:", y_predict)
解析
- 首先生成一组线性关系的模拟数据,然后初始化模型参数 (w) 和 (b)。
- 在随机梯度下降的每次迭代中,随机选择一个样本(由于batch_size=1),计算梯度,并更新参数。
- 通过多次迭代,模型参数将逐渐逼近最优解。
3、总结
随机梯度下降是一种非常有效的优化算法,尤其适用于大规模数据集。
通过在每次迭代中仅使用一个样本(或一小批样本)来更新模型参数,SGD可以显著减少计算量,加快训练速度。
然而,由于每次更新只使用一个样本,这可能会导致算法在优化路径上震荡,因此通常需要配合学习率衰减等策略来确保收敛。
我是小鱼:
- CSDN 博客专家;
- 阿里云 专家博主;
- 51CTO博客专家;
- 企业认证金牌面试官;
- 多个名企认证&特邀讲师等;
- 名企签约职场面试培训、职场规划师;
- 多个国内主流技术社区的认证专家博主;
- 多款主流产品(阿里云等)评测一等奖获得者;
关注小鱼,学习【机器学习】&【深度学习】领域的知识。