```python
import torch
import math
import argparse
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
# Command-line interface for the LR-schedule comparison plot.
# Arguments are parsed inside ``main()`` rather than at import time, so this
# module can be imported (e.g. to reuse the schedule functions) without
# consuming -- or exiting on -- the importing program's ``sys.argv``.
parser = argparse.ArgumentParser(description='LR_Decay Figure')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--init_lr', '--init_learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial (base) learning rate', dest='init_lr')
parser.add_argument('--lr_decay_rate', '--Exp-learning-rate-decay-rate', default=0.9,
                    type=float, metavar='LRDR',
                    help='learning rate decay rate for Exp-decay', dest='lr_decay_rate')
# The exponential schedule computes init_lr * lr_decay_rate ** (epoch / lr_decay_step),
# so this is the number of epochs over which the LR shrinks by one factor of
# lr_decay_rate. It is not an "initial" value.
parser.add_argument('--lr_decay_step', '--Exp-learning-rate-decay-step', default=10,
                    type=float, metavar='LRDS',
                    help='number of epochs per factor of lr_decay_rate for Exp-decay',
                    dest='lr_decay_step')
def adjust_learning_rate_cosine(epoch, args):
    """Return the cosine-annealed learning rate for ``epoch``.

    Computes ``init_lr * 0.5 * (1 + cos(pi * epoch / epochs))``. This equals
    ``args.init_lr`` at epoch 0 and falls to 0 at ``epoch == args.epochs``.

    Args:
        epoch: current epoch index.
        args: object providing ``init_lr`` and ``epochs`` attributes.
    """
    cosine_term = 1. + math.cos(math.pi * epoch / args.epochs)
    return args.init_lr * 0.5 * cosine_term
def adjust_learning_rate_exp(epoch, args):
    """Return the exponentially decayed learning rate for ``epoch``.

    Computes ``init_lr * lr_decay_rate ** (epoch / lr_decay_step)``. This is a
    continuous decay that multiplies the rate by ``lr_decay_rate`` once
    every ``lr_decay_step`` epochs.

    Args:
        epoch: current epoch index.
        args: object providing ``init_lr``, ``lr_decay_rate`` and
            ``lr_decay_step`` attributes.
    """
    exponent = epoch / args.lr_decay_step
    decay_factor = args.lr_decay_rate ** exponent
    return args.init_lr * decay_factor
def main(args=None):
    """Plot the cosine and exponential LR schedules on one figure.

    Args:
        args: optional namespace with ``epochs``, ``init_lr``,
            ``lr_decay_rate`` and ``lr_decay_step``. When ``None`` (the
            default), the values are parsed from the command line via the
            module-level ``parser``.
    """
    if args is None:
        args = parser.parse_args()

    # Keep every value numeric. Passing strings makes matplotlib treat them as
    # categorical labels (one tick per distinct string, in first-seen order)
    # instead of a numeric scale, which makes the LR curves meaningless.
    epochs = list(range(args.epochs))
    lr_cosine = [adjust_learning_rate_cosine(epoch, args) for epoch in epochs]
    lr_exp = [adjust_learning_rate_exp(epoch, args) for epoch in epochs]

    _, ax = plt.subplots()
    ax.plot(epochs, lr_cosine, color='red', linewidth=2.0, linestyle='--',
            label='LR_cosine_decay')
    ax.plot(epochs, lr_exp, color='blue', linewidth=3.0, linestyle='-.',
            label='LR_exp_decay')
    ax.set_title("Comparison of different LR_decay methods.")
    ax.set_xlabel('epoch')
    ax.set_ylabel('learning rate')
    ax.legend()
    # Hide the right and top frame lines for a cleaner plot.
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')
    # A single show() is enough; a second call has nothing new to display.
    plt.show()
if __name__ == "__main__":
    # Draw the comparison plot only when run as a script, not when imported.
    main()