import torch
import torchvision
learing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,
momentum=0.9,
weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)
# scheduler.last_epoch = 8
for epoch in range(9):
optimizer.step()
scheduler.step()
# print(optimizer.get_lr())
print(epoch, scheduler.get_lr())
三段式lr,epoch进入milestones范围内即乘以gamma,离开milestones范围之后再乘以gamma
当不指定last_epoch的结果如下:
1 0 [0.1]
2 1 [0.1]
3 2 [0.0010000000000000002]
4 3 [0.010000000000000002]
5 4 [0.010000000000000002]
6 5 [0.00010000000000000003]
7 6 [0.0010000000000000002]
8 7 [0.0010000000000000002]
9 8 [0.0010000000000000002]
当指定last_epoch=4的结果如下:
5 0 [0.1]
6 1 [0.0010000000000000002]
7 2 [0.010000000000000002]
8 3 [0.010000000000000002]
9 4 [0.010000000000000002]
10 5 [0.010000000000000002]
11 6 [0.010000000000000002]
12 7 [0.010000000000000002]
13 8 [0.010000000000000002]
当指定last_epoch=8的结果如下:
9 0 [0.1]
101 [0.1]
11 2 [0.1]
12 3 [0.1]
13 4 [0.1]
14 5 [0.1]
15 6 [0.1]
16 7 [0.1]
17 8 [0.1]