Torch 训练的一些注意点:优化器params参数 与 lr 调度器篇 [1]

摘要:

  • Torch 多模型的参数分组
  • lr_scheduler.StepLR

我的 Torch 版本

>>> import torch
>>> torch.__version__
'1.8.1+cu111'

截止到目前为止,已经 1.11 了,Torch这个更新速度好快啊啊啊


先建立一个实验环境
import torch
import torch.optim as optim 
import torch.nn as nn


# 这里先建立两个模型
model1 = nn.Conv1d(1, 1, 1, bias=False)
model2 = nn.Conv1d(1, 1, 1, bias=False)
# 自己康康打印了什么
print(model1.state_dict())
print(model2.state_dict())
'''
OrderedDict([('weight', tensor([[[0.4058]]]))])
OrderedDict([('weight', tensor([[[0.8899]]]))])
'''

sgd = optim.SGD(model2.parameters(), lr=1)
sgd = optim.SGD([{"params": model1.parameters(), "initial_lr": 100},
                 {"params": model2.parameters(), "initial_lr": 10}], lr=1000)
Torch 多模型的参数分组

在这一行

sgd = optim.SGD(model2.parameters(), lr=1)

咱们一般看别人代码,优化器的第一个参数 params 直接一个 model2.parameters()

但是实际上可以多个

sgd = optim.SGD([{"params": model1.parameters(), "initial_lr": 100},
                 {"params": model2.parameters(), "initial_lr": 10}], lr=1000)

打印一下参数看看:

>>> sgd.state_dict
<bound method Optimizer.state_dict of SGD (
Parameter Group 0
    dampening: 0
    initial_lr: 100
    lr: 1000
    momentum: 0
    nesterov: False
    weight_decay: 0

Parameter Group 1
    dampening: 0
    initial_lr: 10
    lr: 1000
    momentum: 0
    nesterov: False
    weight_decay: 0
)>

看到了吗,两组参数,SGD对其分别优化,实际上,这个字典
{"params": model1.parameters(), "initial_lr": 100}
还可以塞更多的参数,就是上边儿显示的 momentum, nesterovweight_decay等参数

我们加一个lr参数:

sgd = optim.SGD([{"params": model1.parameters(), "initial_lr": 100},
                 {"params": model2.parameters(), "initial_lr": 10, "lr":2000}], lr=1000)
>>> sgd.state_dict
<bound method Optimizer.state_dict of SGD (
Parameter Group 0
    dampening: 0
    initial_lr: 100
    lr: 1000
    momentum: 0
    nesterov: False
    weight_decay: 0

Parameter Group 1
    dampening: 0
    initial_lr: 10
    lr: 2000           # <------------------ 看到了吗这里改了
    momentum: 0
    nesterov: False
    weight_decay: 0
)>

实际上,你在SGD构造函数中填入的参数,后来会被每个字典中的参数再覆盖一次

你看,SGD构造函数的参数和 state_dict 中的参数一样
在这里插入图片描述

但是initial_lr 参数只能在字典中传入

(当然,实际lr参数一般不设置为1000多)

在实际运行过程中,可能会遇到这个Error,SGD没这个参数,这个参数要放在那个字典里

KeyError: "param 'initial_lr' is not specified in param_groups[0] when resuming an optimizer"
lr_scheduler.StepLR
torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False)
optimizer :  就是你的优化器
step_size :  每隔多少个step 开始衰减一次学习率
gamma     :  衰减比例
last_epoch:  用于指示上一个epoch是多少,比如你已经训练了100个epoch,断点续训时,这里指定100就行
verbose   :  是否会打印出你的学习率

给个例子,自己康康吧:

import torch
import torch.optim as optim 
import torch.nn as nn


model1 = nn.Conv1d(1, 1, 1, bias=False)
model2 = nn.Conv1d(1, 1, 1, bias=False)
# print(model1.state_dict())
# print(model2.state_dict())


sgd = optim.SGD(model2.parameters(), lr=1)
sgd = optim.SGD([{"params": model1.parameters(), "initial_lr": 100},
                 {"params": model2.parameters(), "initial_lr": 10, "lr":2000}], lr=1000)



sc = optim.lr_scheduler.StepLR(sgd, step_size=2, verbose=True, last_epoch=-1, gamma=0.1)
# verbose 设置为 True 就会打印出信息 "Adjusting learning rate of group XXXX"

print("没运行前的学习率: ")
print(sc.get_last_lr(), "\n")


for i in range(5):

    sgd.step()
    sc.step()
    print()
Adjusting learning rate of group 0 to 1.0000e+03.
Adjusting learning rate of group 1 to 2.0000e+03.
没运行前的学习率: 
[1000, 2000] 

Adjusting learning rate of group 0 to 1.0000e+03.
Adjusting learning rate of group 1 to 2.0000e+03.

Adjusting learning rate of group 0 to 1.0000e+02.
Adjusting learning rate of group 1 to 2.0000e+02.

Adjusting learning rate of group 0 to 1.0000e+02.
Adjusting learning rate of group 1 to 2.0000e+02.

Adjusting learning rate of group 0 to 1.0000e+01.
Adjusting learning rate of group 1 to 2.0000e+01.

Adjusting learning rate of group 0 to 1.0000e+01.
Adjusting learning rate of group 1 to 2.0000e+01.

然后 optim.lr_scheduler.StepLR 还有三个常用方法:
get_last_lrget_lrstate_dict

前俩个就是表面含义,获取上一次的学习率,和 最新的学习率

运行 get_lr 会有 warning 推荐你用这个 get_last_lr

>>> sc.state_dict()
{'step_size': 2,
 'gamma': 0.1,
 'base_lrs': [100, 10],
 'last_epoch': 9,
 '_step_count': 6,
 'verbose': True,
 '_get_lr_called_within_step': False,
 '_last_lr': [1.0, 2.0]}

这个也是打印 lr_scheduler 的一些信息

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值