余弦退火法是一种用于调整神经网络训练过程中学习率的技术。它通过余弦函数来逐渐降低学习率,使得模型在训练后期能够更好地收敛,并减小震荡。以下是该方法的工作原理:
-
初始学习率设定:选择一个初始学习率 η m a x \eta_{max} ηmax和一个最小学习率 η m i n \eta_{min} ηmin 。
-
训练周期设定:设定一个周期 T m a x T_{max} Tmax,表示学习率从 η m a x \eta_{max} ηmax 到 η m i n \eta_{min} ηmin 变化的周期。
-
余弦函数形式:学习率在周期内按余弦函数逐渐变化:
η t = η m i n + 1 2 ( η m a x − η m i n ) ( 1 + cos ( t ⋅ π T m a x ) ) \eta_t = \eta_{min} + \frac{1}{2} (\eta_{max} - \eta_{min}) (1 + \cos(\frac{t \cdot \pi}{T_{max}})) ηt=ηmin+21(ηmax−ηmin)(1+cos(Tmaxt⋅π))
其中 η t \eta_t ηt 表示在第 t t t 次迭代时的学习率。
-
周期性重复:当一个周期结束后,学习率可以继续按上述方式在新周期内进行调整,或保持在 η m i n \eta_{min} ηmin 。
该方法的优势在于:在训练初期提供较高的学习率,加速模型的学习,防止其陷入局部最小;在训练后期逐渐降低学习率,使模型能够稳定收敛到其找到的全局最小的最优点,并减少过拟合的风险。
[torch.optim.lr_scheduler.CosineAnnealingLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html)
Parameters
- optimizer (Optimizer) – Wrapped optimizer.
- T_max (int) – Maximum number of iterations.
- eta_min (float) – Minimum learning rate. Default: 0.
- last_epoch (int) – The index of last epoch. Default: -1.
- verbose (bool) –
If True, prints a message to stdout for each update. Default: False.
Note:Deprecated since version 2.2: verbose is deprecated. Please use get_last_lr() to access the learning rate.
# example
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
# 定义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 初始化模型、优化器和损失函数
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
# 初始化余弦退火学习率调度器
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.0001)
# 简单的数据生成函数
def generate_data(batch_size):
inputs = torch.randn(batch_size, 10)
targets = inputs.sum(dim=1, keepdim=True)
return inputs, targets
# 模拟训练过程
for epoch in range(20): # 训练20个周期
for _ in range(100): # 每个周期内有100个批次
inputs, targets = generate_data(32) # 每个批次生成32个样本
outputs = model(inputs)
loss = criterion(outputs, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 每个周期结束后更新学习率
scheduler.step()
print(f'Epoch {epoch+1}, Learning Rate: {scheduler.get_last_lr()[0]}, Loss: {loss.item()}')