1.引入库
代码如下(示例):
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
2.使用 OneCycleLR 调度学习率并绘制学习率曲线
代码如下(示例):
# Demo script: record the learning rate that OneCycleLR assigns to an SGD
# optimizer at every step of one full cycle, then plot the resulting curve.
import os

# Set before importing matplotlib/torch so the flag is already in the
# environment when their native libraries load. It is commonly used to
# silence the "OMP: Error #15" abort caused by duplicate OpenMP runtimes
# (NOTE(review): only needed on affected setups - confirm it is still
# required).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import matplotlib.pyplot as plt
import torch

# Length of the schedule. The loop below must run exactly
# STEPS_PER_EPOCH * EPOCHS times, so the count is derived from these values
# instead of being hard-coded.
STEPS_PER_EPOCH = 2
EPOCHS = 10
TOTAL_STEPS = STEPS_PER_EPOCH * EPOCHS
MAX_LR = 0.1

# A tiny model is enough: only the optimizer's learning rate is inspected.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
)

lrs = []
for _ in range(TOTAL_STEPS):
    # optimizer.step() is called before scheduler.step(), the order PyTorch
    # expects; no backward pass has run, so no parameter has a gradient.
    optimizer.step()
    # Record the learning rate in effect for this step, then advance.
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()

plt.plot(lrs)
plt.show()