PyTorch学习笔记(四)优化器

本文介绍了PyTorch中的优化器,包括SGD、Adagrad、RMSprop、Adadelta、Adam、AdamW、SparseAdam、Adamax、ASGD、LBFGS和Rprop。通过示例详细讲解了SGD的属性、方法及权重衰减,并概述了其他优化器的工作原理。
摘要由CSDN通过智能技术生成

Environment

  • OS: macOS Mojave
  • Python version: 3.7
  • PyTorch version: 1.4.0
  • IDE: PyCharm


0. 写在前面

PyTorch 在 torch.optim 中提供了十来种优化器,它们的基类为 torch.optim.Optimizer。这些优化器用于管理参数(包括超参数和可学习参数),能够更新可学习参数。

参考 PyTorch 官方文档

1. SGD

以一个简单的网络 LeNet 和 SGD 优化器为例,搞一下 optimizer 常用的实例属性和方法

import torch
from torch.nn import Module, Sequential, Conv2d, ReLU, MaxPool2d, Linear


class LeNet5(Module):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()

        self.features = Sequential(
            Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),
            ReLU(),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(6, 16, 5),
            ReLU(),
            MaxPool2d(2, 2)
        )

        self.classifier = Sequential(
            Linear(in_features=16 * 5 * 5, out_features=120),
            ReLU(),
            Linear(120, 84),
            ReLU(),
            Linear(84, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x


lenet = LeNet5()

torch.optim.SGD 类,实现随机梯度下降,可带有动量。

from torch.optim import SGD

optimizer = SGD(
    params=lenet.parameters(),  # 管理的参数组
    lr=0.1,  # 初始学习率
    momentum=.9,  # 动量系数,\beta
    dampening=0,  # 动量抑制。尚未用过...
    weight_decay=0,  # L2 正则化系数
    nesterov=False  # 是否采用 Nesterov 提出的 NAG 梯度下降方法。通常就为默认 False
)

1.1 optimizer 常用的实例属性

  • optimizer.defaults 存储超参数
print(optimizer.defaults)
# {'lr': 0.3, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}

注意,如果使用学习率调整策略后,optimizer.defaults 中的学习率并不会变化,而是 optimizer.param_groups 中的发生变化。


  • optimizer.param_groups 分组存储可学习参数和超参数

数据类型为列表

print(type(optimizer.param_groups))  # <class 'list'>

有几个 group,列表就有几项

print(len(optimizer.param_groups))  # 1

其中每一项是一个字典,键分别为 ‘params’ 和超参数名

print(optimizer.param_groups[0].keys())
# dict_keys(['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'])

学习参数的键 params 对应的值为一个列表

print(type(optimizer.param_groups[0]['params'])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值