pytorch优化器(optimizer)中params参数详细介绍

        这里先给出使用的一个小型网络(自己瞎定义的一个网络),后面使用的model就是这里定义的一个小型的网络:

# 定义网络
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 8, 3),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(288, 5)
        )

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        print(x.shape)
        x = self.fc(x)
        return x

# 实例化网络
model = Test()

        Pytorch的优化器中都有一个参数:params。在这里就详细描述一下params。

        params就是网络中需要优化的网络参数,在这里需要注意的是传入的网络参数必须使可以迭代的对象。

        (1)如果我们只是想优化一个网络,那么我们就把一整个网络看做一个param_groups,params参数的赋值为model.parameters()。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
print(len(optimizer.param_groups))

        因为这里我们把整个网络看做一个param_groups,所以这里我们的执行结果为:1。

        (2)如果我们想要的是同时优化多个网络的参数,这里介绍两种方法:

                a.我们将多个网络的参数合并到一起,当成一个网络的参数来进行优化一般的赋值方式为

model_1 = Test()
model_2 = Test()
optimizer = torch.optim.Adam([*model_1.parameters(), *model_2.parameters()], lr=0.01)
print(len(optimizer.param_groups))

        代码的执行结果为:1。这样我们就可以把多个网络参数合并成一个网络参数进行优化。

                b.多个网络分开优化,并且可以使用各不相同的学习率,赋值方式为:

model_1 = Test()
model_2 = Test()
optimizer = torch.optim.Adam([
    {'params':model_1.parameters()},
    {'params':model_2.parameters(), 'lr': 0.2}
], lr=0.01)
print("优化器里有多少个param_groups: ", len(optimizer.param_groups))
print("网络1的学习率为: ", optimizer.param_groups[0]['lr'])
print("网络2的学习率为: ", optimizer.param_groups[1]['lr'])

          执行结果为:

优化器里有多少个param_groups:  2
网络1的学习率为:  0.01
网络2的学习率为:  0.2

        从这个结果中我们可以看出,每个param_groups中可以单独定义学习率lr,如果没有指定的话则默认采取后面的学习率。

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值