import math
from torch.utils.tensorboard import SummaryWriter
def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
    """Return the learning rate for step ``iters`` of a warmup + cosine schedule.

    The schedule has three phases:

    1. Warmup (``0 < warmup_total_iters`` and ``iters <= warmup_total_iters``):
       rises quadratically from ``warmup_lr_start`` to ``lr``.
    2. Cosine decay (between the end of warmup and ``total_iters - no_aug_iter``):
       anneals from ``lr`` down to ``min_lr`` along half a cosine period.
    3. Tail (``iters >= total_iters - no_aug_iter``): held constant at ``min_lr``.

    Args:
        lr: Peak learning rate, reached at the end of warmup.
        min_lr: Learning rate at the end of the cosine phase and during the tail.
        total_iters: Total number of steps in the schedule.
        warmup_total_iters: Number of warmup steps; 0 disables warmup.
        warmup_lr_start: Learning rate at step 0 of the warmup.
        no_aug_iter: Number of final steps held at ``min_lr``.
        iters: Current step.

    Returns:
        The learning rate for step ``iters``.
    """
    # Example arguments (as used by the demo in this file):
    # lr=1e-4, min_lr=1e-6, total_iters=100, warmup_total_iters=3,
    # warmup_lr_start=1e-5, no_aug_iter=5 -- steps are epochs there.

    # Warmup phase. The `warmup_total_iters > 0` guard avoids 0 / 0.0
    # (ZeroDivisionError) at iters == 0 when warmup is disabled.
    if warmup_total_iters > 0 and iters <= warmup_total_iters:
        progress = iters / float(warmup_total_iters)
        # Quadratic ramp; a linear ramp would use `progress` instead of `progress ** 2`.
        return (lr - warmup_lr_start) * progress ** 2 + warmup_lr_start

    # Tail phase: the last `no_aug_iter` steps stay at the floor value.
    cosine_end = total_iters - no_aug_iter
    if iters >= cosine_end:
        return min_lr

    # Cosine phase: cos goes from 1 to -1 as iters moves from the end of
    # warmup to `cosine_end`, so the rate goes from `lr` to `min_lr`.
    cosine_progress = (iters - warmup_total_iters) / (cosine_end - warmup_total_iters)
    return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * cosine_progress))
lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter = 1e-4, 1e-6, 100, 3, 1e-5, 5
# Demo: log the schedule as the scalar "lr_cos" to the "logaaa" TensorBoard log directory.
writer = SummaryWriter("logaaa")
try:
    # Drive the demo from the configuration tuple above instead of repeating
    # its literals, so the parameters only need editing in one place.
    for iters in range(total_iters):
        y = yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters)
        writer.add_scalar("lr_cos", y, iters)
finally:
    # Close the writer even if computing or logging a value raises.
    writer.close()