pytorch余弦退火学习率,代码和图示 CosineAnnealingLR,CosineAnnealingWarmRestarts

pytorch的余弦退火学习率

torch.optim.lr_scheduler:调整学习率之CosineAnnealingWarmRestarts()参数说明

参考博主关于torch学习率调整的文章,写的非常好。https://blog.csdn.net/qyhaill/article/details/103043637,但是CosineAnnealingWarmRestarts()学习率调整博主没有写,在此借用博主的代码,补充一下此函数的学习率调整曲线。对于此函数,官方解释如下:链接地址https://pytorch.org/docs/stable/optim.html


函数名跟CosinAnnealingLR()相近,参数有两个不一样,T_0就是初始restart的epoch数目,T_mult就是重启之后因子,默认是1。我觉得可以这样理解,每个restart后,T_0 = T_0 * T_mult。当T_mult默认时,上代码


   
   
  1. import torch
  2. import torch.nn as nn
  3. from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
  4. import itertools
  5. import matplotlib.pyplot as plt
  6. initial_lr = 0.1
  7. class model(nn.Module):
  8. def __init__(self):
  9. super().__init__()
  10. self.conv1 = nn.Conv2d(in_channels= 3, out_channels= 3, kernel_size= 3)
  11. def forward(self, x):
  12. pass
  13. net_1 = model()
  14. optimizer_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr)
  15. scheduler_1 = CosineAnnealingWarmRestarts(optimizer_1, T_0= 5)
  16. print( "初始化的学习率:", optimizer_1.defaults[ 'lr'])
  17. lr_list = [] # 把使用过的lr都保存下来,之后画出它的变化
  18. for epoch in range( 1, 101):
  19. # train
  20. optimizer_1.zero_grad()
  21. optimizer_1.step()
  22. print( "第%d个epoch的学习率:%f" % (epoch, optimizer_1.param_groups[ 0][ 'lr']))
  23. lr_list.append(optimizer_1.param_groups[ 0][ 'lr'])
  24. scheduler_1.step()
  25. # 画出lr的变化
  26. plt.plot(list(range( 1, 101)), lr_list)
  27. plt.xlabel( "epoch")
  28. plt.ylabel( "lr")
  29. plt.title( "learning rate's curve changes as epoch goes on!")
  30. plt.show()

输出结果


   
   
  1. 初始化的学习率: 0 .1
  2. 第1个 epoch的学习率:0 .100000
  3. 第2个 epoch的学习率:0 .090451
  4. 第3个 epoch的学习率:0 .065451
  5. 第4个 epoch的学习率:0 .034549
  6. 第5个 epoch的学习率:0 .009549
  7. 第6个 epoch的学习率:0 .100000
  8. 第7个 epoch的学习率:0 .090451
  9. 第8个 epoch的学习率:0 .065451
  10. 第9个 epoch的学习率:0 .034549
  11. 第10个 epoch的学习率:0 .009549
  12. 第11个 epoch的学习率:0 .100000
  13. 第12个 epoch的学习率:0 .090451
  14. 第13个 epoch的学习率:0 .065451
  15. 第14个 epoch的学习率:0 .034549
  16. 第15个 epoch的学习率:0 .009549
  17. 第16个 epoch的学习率:0 .100000
  18. 第17个 epoch的学习率:0 .090451
  19. 第18个 epoch的学习率:0 .065451
  20. 第19个 epoch的学习率:0 .034549
  21. 第20个 epoch的学习率:0 .009549
  22. 第21个 epoch的学习率:0 .100000
  23. 第22个 epoch的学习率:0 .090451
  24. 第23个 epoch的学习率:0 .065451
  25. 第24个 epoch的学习率:0 .034549
  26. 第25个 epoch的学习率:0 .009549
  27. 第26个 epoch的学习率:0 .100000
  28. 第27个 epoch的学习率:0 .090451
  29. 第28个 epoch的学习率:0 .065451
  30. 第29个 epoch的学习率:0 .034549
  31. 第30个 epoch的学习率:0 .009549
  32. 第31个 epoch的学习率:0 .100000
  33. 第32个 epoch的学习率:0 .090451
  34. 第33个 epoch的学习率:0 .065451
  35. 第34个 epoch的学习率:0 .034549
  36. 第35个 epoch的学习率:0 .009549
  37. 第36个 epoch的学习率:0 .100000
  38. 第37个 epoch的学习率:0 .090451
  39. 第38个 epoch的学习率:0 .065451
  40. 第39个 epoch的学习率:0 .034549
  41. 第40个 epoch的学习率:0 .009549
  42. 第41个 epoch的学习率:0 .100000
  43. 第42个 epoch的学习率:0 .090451
  44. 第43个 epoch的学习率:0 .065451
  45. 第44个 epoch的学习率:0 .034549
  46. 第45个 epoch的学习率:0 .009549
  47. 第46个 epoch的学习率:0 .100000
  48. 第47个 epoch的学习率:0 .090451
  49. 第48个 epoch的学习率:0 .065451
  50. 第49个 epoch的学习率:0 .034549
  51. 第50个 epoch的学习率:0 .009549
  52. 第51个 epoch的学习率:0 .100000
  53. 第52个 epoch的学习率:0 .090451
  54. 第53个 epoch的学习率:0 .065451
  55. 第54个 epoch的学习率:0 .034549
  56. 第55个 epoch的学习率:0 .009549
  57. 第56个 epoch的学习率:0 .100000
  58. 第57个 epoch的学习率:0 .090451
  59. 第58个 epoch的学习率:0 .065451
  60. 第59个 epoch的学习率:0 .034549
  61. 第60个 epoch的学习率:0 .009549
  62. 第61个 epoch的学习率:0 .100000
  63. 第62个 epoch的学习率:0 .090451
  64. 第63个 epoch的学习率:0 .065451
  65. 第64个 epoch的学习率:0 .034549
  66. 第65个 epoch的学习率:0 .009549
  67. 第66个 epoch的学习率:0 .100000
  68. 第67个 epoch的学习率:0 .090451
  69. 第68个 epoch的学习率:0 .065451
  70. 第69个 epoch的学习率:0 .034549
  71. 第70个 epoch的学习率:0 .009549
  72. 第71个 epoch的学习率:0 .100000
  73. 第72个 epoch的学习率:0 .090451
  74. 第73个 epoch的学习率:0 .065451
  75. 第74个 epoch的学习率:0 .034549
  76. 第75个 epoch的学习率:0 .009549
  77. 第76个 epoch的学习率:0 .100000
  78. 第77个 epoch的学习率:0 .090451
  79. 第78个 epoch的学习率:0 .065451
  80. 第79个 epoch的学习率:0 .034549
  81. 第80个 epoch的学习率:0 .009549
  82. 第81个 epoch的学习率:0 .100000
  83. 第82个 epoch的学习率:0 .090451
  84. 第83个 epoch的学习率:0 .065451
  85. 第84个 epoch的学习率:0 .034549
  86. 第85个 epoch的学习率:0 .009549
  87. 第86个 epoch的学习率:0 .100000
  88. 第87个 epoch的学习率:0 .090451
  89. 第88个 epoch的学习率:0 .065451
  90. 第89个 epoch的学习率:0 .034549
  91. 第90个 epoch的学习率:0 .009549
  92. 第91个 epoch的学习率:0 .100000
  93. 第92个 epoch的学习率:0 .090451
  94. 第93个 epoch的学习率:0 .065451
  95. 第94个 epoch的学习率:0 .034549
  96. 第95个 epoch的学习率:0 .009549
  97. 第96个 epoch的学习率:0 .100000
  98. 第97个 epoch的学习率:0 .090451
  99. 第98个 epoch的学习率:0 .065451
  100. 第99个 epoch的学习率:0 .034549
  101. 第100个 epoch的学习率:0 .009549

输出图像

当T_mult设置成为2时,代码


   
   
  1. import torch
  2. import torch.nn as nn
  3. from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
  4. import itertools
  5. import matplotlib.pyplot as plt
  6. initial_lr = 0.1
  7. class model(nn.Module):
  8. def __init__(self):
  9. super().__init__()
  10. self.conv1 = nn.Conv2d(in_channels= 3, out_channels= 3, kernel_size= 3)
  11. def forward(self, x):
  12. pass
  13. net_1 = model()
  14. optimizer_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr)
  15. scheduler_1 = CosineAnnealingWarmRestarts(optimizer_1, T_0= 5, T_mult= 2)
  16. print( "初始化的学习率:", optimizer_1.defaults[ 'lr'])
  17. lr_list = [] # 把使用过的lr都保存下来,之后画出它的变化
  18. for epoch in range( 1, 101):
  19. # train
  20. optimizer_1.zero_grad()
  21. optimizer_1.step()
  22. print( "第%d个epoch的学习率:%f" % (epoch, optimizer_1.param_groups[ 0][ 'lr']))
  23. lr_list.append(optimizer_1.param_groups[ 0][ 'lr'])
  24. scheduler_1.step()
  25. # 画出lr的变化
  26. plt.plot(list(range( 1, 101)), lr_list)
  27. plt.xlabel( "epoch")
  28. plt.ylabel( "lr")
  29. plt.title( "learning rate's curve changes as epoch goes on!")
  30. plt.show()

输出结果


   
   
  1. 初始化的学习率: 0 .1
  2. 第1个 epoch的学习率:0 .100000
  3. 第2个 epoch的学习率:0 .090451
  4. 第3个 epoch的学习率:0 .065451
  5. 第4个 epoch的学习率:0 .034549
  6. 第5个 epoch的学习率:0 .009549
  7. 第6个 epoch的学习率:0 .100000
  8. 第7个 epoch的学习率:0 .097553
  9. 第8个 epoch的学习率:0 .090451
  10. 第9个 epoch的学习率:0 .079389
  11. 第10个 epoch的学习率:0 .065451
  12. 第11个 epoch的学习率:0 .050000
  13. 第12个 epoch的学习率:0 .034549
  14. 第13个 epoch的学习率:0 .020611
  15. 第14个 epoch的学习率:0 .009549
  16. 第15个 epoch的学习率:0 .002447
  17. 第16个 epoch的学习率:0 .100000
  18. 第17个 epoch的学习率:0 .099384
  19. 第18个 epoch的学习率:0 .097553
  20. 第19个 epoch的学习率:0 .094550
  21. 第20个 epoch的学习率:0 .090451
  22. 第21个 epoch的学习率:0 .085355
  23. 第22个 epoch的学习率:0 .079389
  24. 第23个 epoch的学习率:0 .072700
  25. 第24个 epoch的学习率:0 .065451
  26. 第25个 epoch的学习率:0 .057822
  27. 第26个 epoch的学习率:0 .050000
  28. 第27个 epoch的学习率:0 .042178
  29. 第28个 epoch的学习率:0 .034549
  30. 第29个 epoch的学习率:0 .027300
  31. 第30个 epoch的学习率:0 .020611
  32. 第31个 epoch的学习率:0 .014645
  33. 第32个 epoch的学习率:0 .009549
  34. 第33个 epoch的学习率:0 .005450
  35. 第34个 epoch的学习率:0 .002447
  36. 第35个 epoch的学习率:0 .000616
  37. 第36个 epoch的学习率:0 .100000
  38. 第37个 epoch的学习率:0 .099846
  39. 第38个 epoch的学习率:0 .099384
  40. 第39个 epoch的学习率:0 .098618
  41. 第40个 epoch的学习率:0 .097553
  42. 第41个 epoch的学习率:0 .096194
  43. 第42个 epoch的学习率:0 .094550
  44. 第43个 epoch的学习率:0 .092632
  45. 第44个 epoch的学习率:0 .090451
  46. 第45个 epoch的学习率:0 .088020
  47. 第46个 epoch的学习率:0 .085355
  48. 第47个 epoch的学习率:0 .082472
  49. 第48个 epoch的学习率:0 .079389
  50. 第49个 epoch的学习率:0 .076125
  51. 第50个 epoch的学习率:0 .072700
  52. 第51个 epoch的学习率:0 .069134
  53. 第52个 epoch的学习率:0 .065451
  54. 第53个 epoch的学习率:0 .061672
  55. 第54个 epoch的学习率:0 .057822
  56. 第55个 epoch的学习率:0 .053923
  57. 第56个 epoch的学习率:0 .050000
  58. 第57个 epoch的学习率:0 .046077
  59. 第58个 epoch的学习率:0 .042178
  60. 第59个 epoch的学习率:0 .038328
  61. 第60个 epoch的学习率:0 .034549
  62. 第61个 epoch的学习率:0 .030866
  63. 第62个 epoch的学习率:0 .027300
  64. 第63个 epoch的学习率:0 .023875
  65. 第64个 epoch的学习率:0 .020611
  66. 第65个 epoch的学习率:0 .017528
  67. 第66个 epoch的学习率:0 .014645
  68. 第67个 epoch的学习率:0 .011980
  69. 第68个 epoch的学习率:0 .009549
  70. 第69个 epoch的学习率:0 .007368
  71. 第70个 epoch的学习率:0 .005450
  72. 第71个 epoch的学习率:0 .003806
  73. 第72个 epoch的学习率:0 .002447
  74. 第73个 epoch的学习率:0 .001382
  75. 第74个 epoch的学习率:0 .000616
  76. 第75个 epoch的学习率:0 .000154
  77. 第76个 epoch的学习率:0 .100000
  78. 第77个 epoch的学习率:0 .099961
  79. 第78个 epoch的学习率:0 .099846
  80. 第79个 epoch的学习率:0 .099653
  81. 第80个 epoch的学习率:0 .099384
  82. 第81个 epoch的学习率:0 .099039
  83. 第82个 epoch的学习率:0 .098618
  84. 第83个 epoch的学习率:0 .098123
  85. 第84个 epoch的学习率:0 .097553
  86. 第85个 epoch的学习率:0 .096910
  87. 第86个 epoch的学习率:0 .096194
  88. 第87个 epoch的学习率:0 .095407
  89. 第88个 epoch的学习率:0 .094550
  90. 第89个 epoch的学习率:0 .093625
  91. 第90个 epoch的学习率:0 .092632
  92. 第91个 epoch的学习率:0 .091573
  93. 第92个 epoch的学习率:0 .090451
  94. 第93个 epoch的学习率:0 .089266
  95. 第94个 epoch的学习率:0 .088020
  96. 第95个 epoch的学习率:0 .086716
  97. 第96个 epoch的学习率:0 .085355
  98. 第97个 epoch的学习率:0 .083940
  99. 第98个 epoch的学习率:0 .082472
  100. 第99个 epoch的学习率:0 .080955
  101. 第100个 epoch的学习率:0 .079389

输出图像

可以看出来,当T_mult设置为2时,当epoch=5时重启依次,下一次T_0 = T_0 * T_mul此时T_0等于10,在第16次重启,下一阶段,T_0 = T_0 * T_mult 此时T_0等于20再20个epcoh重启。所以曲线重启越来越缓慢,依次在第5,5+5*2=15,15+10*2=35,35+20 * 2=75次时重启。

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值