【PyTorch教程】P24 优化器

本文详细介绍了PyTorch中的优化器如何利用反向传播调整模型参数,特别关注了optim.zero_grad()和反向传播的过程。通过示例展示了SGD优化器的使用,并在训练过程中观察梯度变化。代码示例中包含了模型定义、损失函数、优化器配置以及训练循环,用于理解优化器在深度学习模型训练中的作用。
摘要由CSDN通过智能技术生成

P24 优化器

  • 优化器利用反向传播,对参数进行调整。

  • 官网中的位置:介绍了优化器的构造过程:
    在这里插入图片描述

  • optim里面的算法理论很深入,如果不深究,只要parameter和lr需要设置,其他的都是默认参数:
    在这里插入图片描述

  • 上图关注重点在optim.zero_grad()和后面的两行,用调试功能,查看梯度是否有数值。

  • 下图是查看的位置:注意每次运行一步之后的梯度是否有数值:
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

  • Optim.Step()之后的结果:
    在这里插入图片描述

  • 以上就是一轮学习(one epoch)的过程,再在外面嵌套一个20轮的循环:
    在这里插入图片描述

  • 再设置一个running_loss:把每一轮训练的损失构造出来:
    在这里插入图片描述

可以运行的代码

# -*- coding: utf-8 -*-

'''
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
scheduler = StepLR(optim, step_size=5, gamma=0.1)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()   # 每次循环,都要把梯度清零
        result_loss.backward()
        scheduler.step()
        running_loss = running_loss + result_loss
    print(running_loss)



# -*- coding: utf-8 -*-

'''

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear
from torch.nn.modules.flatten import Flatten
from torch.utils.data import DataLoader


dataset = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=64)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = MaxPool2d(2)
        self.flatten = Flatten()
        self.linear1 = Linear(64 * 4 * 4, 64)
        self.linear2 = Linear(64, 10)

    def forward(self, m):
        m = self.conv1(m)
        m = self.maxpool1(m)
        m = self.conv2(m)
        m = self.maxpool2(m)
        m = self.conv3(m)
        m = self.maxpool3(m)
        m = self.flatten(m)
        m = self.linear1(m)
        m = self.linear2(m)
        return m


loss = nn.CrossEntropyLoss()  # 定义损失函数
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        # print(outputs)
        # print(targets)
        result_loss = loss(outputs, targets)  # 调用损失函数
        optim.zero_grad()
        result_loss.backward()  # 反向传播, 这里要注意不能使用定义损失函数那里的 loss,而要使用 调用损失函数之后的 result_loss
        optim.step()
        # print("OK")    # 这部分,在debug中可以看到 grad 通过反向传播之后,才有值,debug修好了之后,再来看这里
        # print(result_loss)
        running_loss = running_loss + result_loss
    print(running_loss)

完整目录

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值