torch.optim.lr_scheduler:调整学习率之CosineAnnealingWarmRestarts()参数说明
参考博主关于torch学习率调整的文章,写的非常好。https://blog.csdn.net/qyhaill/article/details/103043637,但是CosineAnnealingWarmRestarts()学习率调整博主没有写,在此借用博主的代码,补充一下此函数的学习率调整曲线。对于此函数,官方解释如下:链接地址https://pytorch.org/docs/stable/optim.html
函数名跟CosinAnnealingLR()相近,参数有两个不一样,T_0就是初始restart的epoch数目,T_mult就是重启之后因子,默认是1。我觉得可以这样理解,每个restart后,T_0 = T_0 * T_mult。当T_mult默认时,上代码
-
import torch
-
import torch.nn
as nn
-
from torch.optim.lr_scheduler
import CosineAnnealingLR, CosineAnnealingWarmRestarts
-
import itertools
-
-
import matplotlib.pyplot
as plt
-
-
initial_lr =
0.1
-
-
-
class model(nn.Module):
-
def __init__(self):
-
super().__init__()
-
self.conv1 = nn.Conv2d(in_channels=
3, out_channels=
3, kernel_size=
3)
-
-
def forward(self, x):
-
pass
-
-
-
net_1 = model()
-
-
optimizer_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr)
-
scheduler_1 = CosineAnnealingWarmRestarts(optimizer_1, T_0=
5)
-
-
print(
"初始化的学习率:", optimizer_1.defaults[
'lr'])
-
-
lr_list = []
# 把使用过的lr都保存下来,之后画出它的变化
-
-
for epoch
in range(
1,
101):
-
# train
-
-
optimizer_1.zero_grad()
-
optimizer_1.step()
-
print(
"第%d个epoch的学习率:%f" % (epoch, optimizer_1.param_groups[
0][
'lr']))
-
lr_list.append(optimizer_1.param_groups[
0][
'lr'])
-
scheduler_1.step()
-
-
# 画出lr的变化
-
plt.plot(list(range(
1,
101)), lr_list)
-
plt.xlabel(
"epoch")
-
plt.ylabel(
"lr")
-
plt.title(
"learning rate's curve changes as epoch goes on!")
-
plt.show()
输出结果
-
初始化的学习率: 0
.1
-
第1个
epoch的学习率:0
.100000
-
第2个
epoch的学习率:0
.090451
-
第3个
epoch的学习率:0
.065451
-
第4个
epoch的学习率:0
.034549
-
第5个
epoch的学习率:0
.009549
-
第6个
epoch的学习率:0
.100000
-
第7个
epoch的学习率:0
.090451
-
第8个
epoch的学习率:0
.065451
-
第9个
epoch的学习率:0
.034549
-
第10个
epoch的学习率:0
.009549
-
第11个
epoch的学习率:0
.100000
-
第12个
epoch的学习率:0
.090451
-
第13个
epoch的学习率:0
.065451
-
第14个
epoch的学习率:0
.034549
-
第15个
epoch的学习率:0
.009549
-
第16个
epoch的学习率:0
.100000
-
第17个
epoch的学习率:0
.090451
-
第18个
epoch的学习率:0
.065451
-
第19个
epoch的学习率:0
.034549
-
第20个
epoch的学习率:0
.009549
-
第21个
epoch的学习率:0
.100000
-
第22个
epoch的学习率:0
.090451
-
第23个
epoch的学习率:0
.065451
-
第24个
epoch的学习率:0
.034549
-
第25个
epoch的学习率:0
.009549
-
第26个
epoch的学习率:0
.100000
-
第27个
epoch的学习率:0
.090451
-
第28个
epoch的学习率:0
.065451
-
第29个
epoch的学习率:0
.034549
-
第30个
epoch的学习率:0
.009549
-
第31个
epoch的学习率:0
.100000
-
第32个
epoch的学习率:0
.090451
-
第33个
epoch的学习率:0
.065451
-
第34个
epoch的学习率:0
.034549
-
第35个
epoch的学习率:0
.009549
-
第36个
epoch的学习率:0
.100000
-
第37个
epoch的学习率:0
.090451
-
第38个
epoch的学习率:0
.065451
-
第39个
epoch的学习率:0
.034549
-
第40个
epoch的学习率:0
.009549
-
第41个
epoch的学习率:0
.100000
-
第42个
epoch的学习率:0
.090451
-
第43个
epoch的学习率:0
.065451
-
第44个
epoch的学习率:0
.034549
-
第45个
epoch的学习率:0
.009549
-
第46个
epoch的学习率:0
.100000
-
第47个
epoch的学习率:0
.090451
-
第48个
epoch的学习率:0
.065451
-
第49个
epoch的学习率:0
.034549
-
第50个
epoch的学习率:0
.009549
-
第51个
epoch的学习率:0
.100000
-
第52个
epoch的学习率:0
.090451
-
第53个
epoch的学习率:0
.065451
-
第54个
epoch的学习率:0
.034549
-
第55个
epoch的学习率:0
.009549
-
第56个
epoch的学习率:0
.100000
-
第57个
epoch的学习率:0
.090451
-
第58个
epoch的学习率:0
.065451
-
第59个
epoch的学习率:0
.034549
-
第60个
epoch的学习率:0
.009549
-
第61个
epoch的学习率:0
.100000
-
第62个
epoch的学习率:0
.090451
-
第63个
epoch的学习率:0
.065451
-
第64个
epoch的学习率:0
.034549
-
第65个
epoch的学习率:0
.009549
-
第66个
epoch的学习率:0
.100000
-
第67个
epoch的学习率:0
.090451
-
第68个
epoch的学习率:0
.065451
-
第69个
epoch的学习率:0
.034549
-
第70个
epoch的学习率:0
.009549
-
第71个
epoch的学习率:0
.100000
-
第72个
epoch的学习率:0
.090451
-
第73个
epoch的学习率:0
.065451
-
第74个
epoch的学习率:0
.034549
-
第75个
epoch的学习率:0
.009549
-
第76个
epoch的学习率:0
.100000
-
第77个
epoch的学习率:0
.090451
-
第78个
epoch的学习率:0
.065451
-
第79个
epoch的学习率:0
.034549
-
第80个
epoch的学习率:0
.009549
-
第81个
epoch的学习率:0
.100000
-
第82个
epoch的学习率:0
.090451
-
第83个
epoch的学习率:0
.065451
-
第84个
epoch的学习率:0
.034549
-
第85个
epoch的学习率:0
.009549
-
第86个
epoch的学习率:0
.100000
-
第87个
epoch的学习率:0
.090451
-
第88个
epoch的学习率:0
.065451
-
第89个
epoch的学习率:0
.034549
-
第90个
epoch的学习率:0
.009549
-
第91个
epoch的学习率:0
.100000
-
第92个
epoch的学习率:0
.090451
-
第93个
epoch的学习率:0
.065451
-
第94个
epoch的学习率:0
.034549
-
第95个
epoch的学习率:0
.009549
-
第96个
epoch的学习率:0
.100000
-
第97个
epoch的学习率:0
.090451
-
第98个
epoch的学习率:0
.065451
-
第99个
epoch的学习率:0
.034549
-
第100个
epoch的学习率:0
.009549
输出图像
当T_mult设置成为2时,代码
-
import torch
-
import torch.nn
as nn
-
from torch.optim.lr_scheduler
import CosineAnnealingLR, CosineAnnealingWarmRestarts
-
import itertools
-
-
import matplotlib.pyplot
as plt
-
-
initial_lr =
0.1
-
-
-
class model(nn.Module):
-
def __init__(self):
-
super().__init__()
-
self.conv1 = nn.Conv2d(in_channels=
3, out_channels=
3, kernel_size=
3)
-
-
def forward(self, x):
-
pass
-
-
-
net_1 = model()
-
-
optimizer_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr)
-
scheduler_1 = CosineAnnealingWarmRestarts(optimizer_1, T_0=
5, T_mult=
2)
-
-
print(
"初始化的学习率:", optimizer_1.defaults[
'lr'])
-
-
lr_list = []
# 把使用过的lr都保存下来,之后画出它的变化
-
-
for epoch
in range(
1,
101):
-
# train
-
-
optimizer_1.zero_grad()
-
optimizer_1.step()
-
print(
"第%d个epoch的学习率:%f" % (epoch, optimizer_1.param_groups[
0][
'lr']))
-
lr_list.append(optimizer_1.param_groups[
0][
'lr'])
-
scheduler_1.step()
-
-
# 画出lr的变化
-
plt.plot(list(range(
1,
101)), lr_list)
-
plt.xlabel(
"epoch")
-
plt.ylabel(
"lr")
-
plt.title(
"learning rate's curve changes as epoch goes on!")
-
plt.show()
输出结果
-
初始化的学习率: 0
.1
-
第1个
epoch的学习率:0
.100000
-
第2个
epoch的学习率:0
.090451
-
第3个
epoch的学习率:0
.065451
-
第4个
epoch的学习率:0
.034549
-
第5个
epoch的学习率:0
.009549
-
第6个
epoch的学习率:0
.100000
-
第7个
epoch的学习率:0
.097553
-
第8个
epoch的学习率:0
.090451
-
第9个
epoch的学习率:0
.079389
-
第10个
epoch的学习率:0
.065451
-
第11个
epoch的学习率:0
.050000
-
第12个
epoch的学习率:0
.034549
-
第13个
epoch的学习率:0
.020611
-
第14个
epoch的学习率:0
.009549
-
第15个
epoch的学习率:0
.002447
-
第16个
epoch的学习率:0
.100000
-
第17个
epoch的学习率:0
.099384
-
第18个
epoch的学习率:0
.097553
-
第19个
epoch的学习率:0
.094550
-
第20个
epoch的学习率:0
.090451
-
第21个
epoch的学习率:0
.085355
-
第22个
epoch的学习率:0
.079389
-
第23个
epoch的学习率:0
.072700
-
第24个
epoch的学习率:0
.065451
-
第25个
epoch的学习率:0
.057822
-
第26个
epoch的学习率:0
.050000
-
第27个
epoch的学习率:0
.042178
-
第28个
epoch的学习率:0
.034549
-
第29个
epoch的学习率:0
.027300
-
第30个
epoch的学习率:0
.020611
-
第31个
epoch的学习率:0
.014645
-
第32个
epoch的学习率:0
.009549
-
第33个
epoch的学习率:0
.005450
-
第34个
epoch的学习率:0
.002447
-
第35个
epoch的学习率:0
.000616
-
第36个
epoch的学习率:0
.100000
-
第37个
epoch的学习率:0
.099846
-
第38个
epoch的学习率:0
.099384
-
第39个
epoch的学习率:0
.098618
-
第40个
epoch的学习率:0
.097553
-
第41个
epoch的学习率:0
.096194
-
第42个
epoch的学习率:0
.094550
-
第43个
epoch的学习率:0
.092632
-
第44个
epoch的学习率:0
.090451
-
第45个
epoch的学习率:0
.088020
-
第46个
epoch的学习率:0
.085355
-
第47个
epoch的学习率:0
.082472
-
第48个
epoch的学习率:0
.079389
-
第49个
epoch的学习率:0
.076125
-
第50个
epoch的学习率:0
.072700
-
第51个
epoch的学习率:0
.069134
-
第52个
epoch的学习率:0
.065451
-
第53个
epoch的学习率:0
.061672
-
第54个
epoch的学习率:0
.057822
-
第55个
epoch的学习率:0
.053923
-
第56个
epoch的学习率:0
.050000
-
第57个
epoch的学习率:0
.046077
-
第58个
epoch的学习率:0
.042178
-
第59个
epoch的学习率:0
.038328
-
第60个
epoch的学习率:0
.034549
-
第61个
epoch的学习率:0
.030866
-
第62个
epoch的学习率:0
.027300
-
第63个
epoch的学习率:0
.023875
-
第64个
epoch的学习率:0
.020611
-
第65个
epoch的学习率:0
.017528
-
第66个
epoch的学习率:0
.014645
-
第67个
epoch的学习率:0
.011980
-
第68个
epoch的学习率:0
.009549
-
第69个
epoch的学习率:0
.007368
-
第70个
epoch的学习率:0
.005450
-
第71个
epoch的学习率:0
.003806
-
第72个
epoch的学习率:0
.002447
-
第73个
epoch的学习率:0
.001382
-
第74个
epoch的学习率:0
.000616
-
第75个
epoch的学习率:0
.000154
-
第76个
epoch的学习率:0
.100000
-
第77个
epoch的学习率:0
.099961
-
第78个
epoch的学习率:0
.099846
-
第79个
epoch的学习率:0
.099653
-
第80个
epoch的学习率:0
.099384
-
第81个
epoch的学习率:0
.099039
-
第82个
epoch的学习率:0
.098618
-
第83个
epoch的学习率:0
.098123
-
第84个
epoch的学习率:0
.097553
-
第85个
epoch的学习率:0
.096910
-
第86个
epoch的学习率:0
.096194
-
第87个
epoch的学习率:0
.095407
-
第88个
epoch的学习率:0
.094550
-
第89个
epoch的学习率:0
.093625
-
第90个
epoch的学习率:0
.092632
-
第91个
epoch的学习率:0
.091573
-
第92个
epoch的学习率:0
.090451
-
第93个
epoch的学习率:0
.089266
-
第94个
epoch的学习率:0
.088020
-
第95个
epoch的学习率:0
.086716
-
第96个
epoch的学习率:0
.085355
-
第97个
epoch的学习率:0
.083940
-
第98个
epoch的学习率:0
.082472
-
第99个
epoch的学习率:0
.080955
-
第100个
epoch的学习率:0
.079389
输出图像
可以看出来,当T_mult设置为2时,当epoch=5时重启依次,下一次T_0 = T_0 * T_mul此时T_0等于10,在第16次重启,下一阶段,T_0 = T_0 * T_mult 此时T_0等于20再20个epcoh重启。所以曲线重启越来越缓慢,依次在第5,5+5*2=15,15+10*2=35,35+20 * 2=75次时重启。