【深度学习技巧】学习率-余弦退火

1. 理解

1.1. Warm up

由于一开始参数不稳定,梯度较大,如果此时学习率设置过大可能导致数值不稳定。使用warm up有助于减缓模型在初始阶段对mini-batch的提前过拟合现象,保持分布的平稳,其次也有助于保持模型深层的稳定性。

1.2 Cosine Anneal

梯度下降算法优化目标函数,使得loss值接近全局最小值,学习率应当变得更小,才能够使其更加容易实现。余弦退火首先使得学习率先缓慢下降,然后再快速下降,可以满足上面的需求。
公式如下:
在这里插入图片描述
其中,ηmax为学习率最大值,ηmin为最小值,Tcur为当前轮次,Tmax为半个周期。

1.3 余弦退火结合

其示意图如下所示:
在这里插入图片描述

2. 代码

import math
import torch 
from math import cos, pi

def warmup_cosine(optimizer, current_epoch, max_epoch, lr_min=0, lr_max=0.1, warmup_epoch = 10):
    if current_epoch < warmup_epoch:
        lr = lr_max * current_epoch / warmup_epoch
    else:
        lr = lr = lr_min + (lr_max-lr_min)*(1 + cos(pi * (current_epoch - warmup_epoch) / (max_epoch - warmup_epoch))) / 2
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    

# 学习率的设置
lr_max = 0.01
lr_min = 0.0001
max_epoch = 1000
warmup_epoch = 250

for epoch in range(max_epoch):
    #训练
    warmup_cosine(optimizer=optimizer,current_epoch=epoch, max_epoch=max_epoch, 	lr_min=lr_min,lr_max=lr_max, warmup_epoch = warmup_epoch)
    print(optimizer.param_groups[0]['lr'])
    cnt = 0
    losssum = 0
    
    model.train()
    for image, label, label_p in tqdm(dataloader):
        optimizer.zero_grad()
        image, label, label_p = image.to(device).float(), label.to(device), label_p.to(device)

        out = model(image)   # torch.Size([10, 2, 2])
        
        #计算loss
        loss = mse(out, label_p)  # label_p : torch.Size([10, 2, 2])  batchsize 10
        loss.backward()
        optimizer.step()

3. 学习率的总结

https://blog.csdn.net/qq_35091353/article/details/117322293

4. 参考

https://blog.csdn.net/qq_40268672/article/details/121145630
https://cloud.tencent.com/developer/article/1749002

  • 2
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

只搬烫手的砖

你的鼓励将是我最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值