PyTorch 的 torch.optim
模块提供了多种优化算法,适用于不同的深度学习任务。以下是一些常用的优化器及其特点:
1. 随机梯度下降(SGD, Stochastic Gradient Descent)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
- 特点:
- 最基本的优化算法,直接沿梯度方向更新参数。
- 可以添加
momentum
(动量)来加速收敛,避免陷入局部极小值。 - 适用于简单任务或需要精细调参的场景。
- 适用场景:
- 训练较简单的模型(如线性回归、SVM)。
- 结合学习率调度器(如
StepLR
)使用效果更好。
2. Adam(Adaptive Moment Estimation)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
- 特点:
- 自适应调整学习率,结合动量(Momentum)和 RMSProp 的优点。
- 默认学习率
lr=0.001
通常表现良好,适合大多数任务。 - 适用于大规模数据、深度网络。
- 适用场景:
- 深度学习(CNN、RNN、Transformer)。
- 当不确定用什么优化器时,Adam 通常是首选。
3. RMSProp(Root Mean Square Propagation)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99)
- 特点:
- 自适应学习率,对梯度平方进行指数加权平均。
- 适用于非平稳目标(如 NLP、RL 任务)。
- 对学习率比较敏感,需要调参。
- 适用场景:
- 循环神经网络(RNN/LSTM)。
- 强化学习(PPO、A2C)。
4. Adagrad(Adaptive Gradient)
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)
- 特点:
- 自适应调整学习率,对稀疏数据友好。
- 学习率会逐渐减小,可能导致训练后期更新太小。
- 适用场景:
- 推荐系统(如矩阵分解)。
- 处理稀疏特征(如 NLP 中的词嵌入)。
5. Adadelta
optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0, rho=0.9)
- 特点:
- Adagrad 的改进版,不需要手动设置初始学习率。
- 适用于长时间训练的任务。
- 适用场景:
- 计算机视觉(如目标检测)。
- 当不想调学习率时可用。
6. AdamW(Adam + Weight Decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
- 特点:
- Adam 的改进版,更正确的权重衰减(L2 正则化)实现。
- 适用于 Transformer 等现代架构。
- 适用场景:
- BERT、GPT 等大模型训练。
- 需要正则化的任务。
7. NAdam(Nesterov-accelerated Adam)
optimizer = torch.optim.NAdam(model.parameters(), lr=0.001)
- 特点:
- 结合了 Nesterov 动量和 Adam,收敛更快。
- 适用场景:
- 需要快速收敛的任务(如 GAN 训练)。
如何选择合适的优化器?
优化器 | 适用场景 | 是否需要调参 |
---|---|---|
SGD + Momentum | 简单任务、调参敏感任务 | 需要调 lr 和 momentum |
Adam | 深度学习(CNN/RNN/Transformer) | 默认 lr=0.001 通常可用 |
RMSProp | RNN/LSTM、强化学习 | 需要调 lr 和 alpha |
Adagrad | 稀疏数据(推荐系统/NLP) | 学习率会自动调整 |
AdamW | Transformer/BERT/GPT | 适用于权重衰减任务 |
NAdam | 快速收敛(如 GAN) | 类似 Adam,但更快 |
总结
- 推荐新手使用
Adam
或AdamW
,因为它们自适应学习率,调参简单。 - 如果需要极致性能,可以尝试
SGD + Momentum
+ 学习率调度(如StepLR
或CosineAnnealingLR
)。 - RNN/LSTM 可以试试
RMSProp
。 - 大模型训练(如 BERT)优先
AdamW
。