Pytorch图像分类:01使用PyTorch搭建LeNet模型

简介】:基于cifar10使用PyTorch搭建LeNet模型进行图片分类
【参考】:2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili
【代码完整版】:01LeNet (github.com)

注:本人还在学习初期,此文是为了梳理自己所学整理的,有些说法是自己的理解,不一定对,如有差错,请批评指正!


1.搭建LeNet模型

说明:常见的一个完整的深度学习项目处理步骤为:数据处理-搭建模型-设置损失函数和优化器-模型训练-模型评估,其实划分粗略一些,可以是模型训练前-模型训练-模型训练后,训练前是做一些准备工作,如数据准备、模型构建,准备前的这些步骤可以换顺序,反正只要在模型训练之前准备好就行。此项目是先构建模型,再进行的后续步骤,所以也没问题。但如果想先处理数据,再构建模型,也可以,具体可以看下面步骤2。
    首先搭建LeNet网络模型,以下是LeNet的网络结构:
LeNet
在这里插入图片描述

新建一个文件model.py,定义LeNet类,它是继承于torch.nn.Module类的,在类里面要写两个函数,分别是__init__()forword(),init()里面包含的是构建模型的基本网络结构,如卷积层、池化层、全连接层,forward()是前向传播函数,通过init()里面定义的基本网络结构,将模型串(搭建)出来,代码如下:

import torch
from torch import nn
from torch.nn import  functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet,self).__init__()
        # self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)与下面等价
        self.conv1=nn.Conv2d(3,6,5)
        # self.pool1=nn.MaxPool2d(kernel_size=2,stride=2)与下面等价
        self.pool1=nn.MaxPool2d(2,2)
        self.conv2=nn.Conv2d(6,16,5)
        self.pool2=nn.MaxPool2d(2,2)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)

    def forwar(self,x):
        x=F.relu(self.conv1(x))#input(3,32,32),output(6,28,28)
        x=self.pool1(x)        #output(6,14,14)
        x=F.relu(self.conv2(x))#output(16,10,10)
        x=self.pool2(x)        #output(16,5,5)
        x=x.view(-1,16*5*5)    #output(16*5*5)
        x=F.relu(self.fc1(x))  #output(120)
        x=F.relu(self.fc2(x))  #output(84)
        x=self.fc3(x)          #output(10)
        return x

打印model的网络结构:

input=torch.rand([32,3,32,32])#batch-size,channel,width,height
model=LeNet()
print(model)

在这里插入图片描述
output是怎么算出来的呢?
以下是公式:
在这里插入图片描述
以conv1为例,输入3 * 32 * 32,kernel-size=5,padding=0,stride默认为1,w=h=32,w1=(32-5)/1+1=28

2. 训练

2.1数据预处理

数据集介绍:
这里使用PyTorch官方提供的cifar10数据集,只需调用函数即可下载
数据集链接:Training a Classifier — PyTorch Tutorials 2.4.0+cu124 documentation
cifar10

1)定义数据预处理工具包transform

在数据进入模型之前需要进行一些预处理,例如数据中心化(仅减均值),数据标准化(减均值,再除以标准差),随机裁剪,旋转一定角度,镜像等一系列操作, 在 PyTorch 中,这些数据增强方法被放在了torchvision/transforms/transforms.py文件中。

可以把transforms看作是一个数据预处理工具包,它里面用于处理的每个函数可以看作是单独的一个工具,我们可以用transforms.Compose将多个工具组合起来形成自己的工具包。

import torch
import torchvision
from torchvision import transforms

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0,5),(0.5,0.5,0.5))
])

现在这个工具包里面有两个处理工具,一个是ToTensor(),一个是Normarlize(),点击跳转到ToTensor()的源码部分,可以看到它的功能是将PIL Image和numpy类型的数据转换为Tensor类型,转换之前,数据的形状为H×W×C,并且每个值处于0~255之间,转换为Tensor类型后,数据的形状为C×H×W,每个值的范围也变成了[0,1]。
transforms-totensor()
Normalize()的作用是对数据按通道进行标准化,即将输入数据减去均值,再除以方差,经它处理后,数据的值的范围将由[0,1]变为[-1, 1]。
Normalize

2)读取数据,进行数据预处理

#50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

通过CIFAR10导入了训练集,然后将训练集通过transformer的预处理函数对每个图像进行预处理。首次训练需要将download设为True,它会将数据集下载到设定的路径下。
CIFAR10来自torchvision.datasets,这里面除了CIFAR10,还有ImageNet, MNIST等数据集。

3)加载数据

trainloader=torch.utils.data.DataLoader(trainset,batch_size=16,shuffle=True,num_workers=0)

将训练集导入进来,分成一个批次一个批次地进行训练,将训练集打乱,分成16个批次(每个批次的数据是随机提取出来的),工作进程设置为0表示只会有一个主进程。

测试集图片直接按上面的思路写:

#10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=False, transform=transform)
testloader=torch.utils.data.DataLoader(testset,batch_size=10000,shuffle=True,num_workers=0)

2.2定义模型、损失函数、优化器

直接导入刚才搭建的模型

from model import LeNet

model=LeNet()

定义损失函数和优化器,这里使用交叉熵损失函数、Adam优化器

from torch import nn

loss_function=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)

2.3模型训练

epoches=5
    for epoch in range(epoches):
        running_loss=0.0
        for step,data in enumerate(trainloader):
            # get the inputs; data is a list of [inputs, labels]
            inputs,labels=data
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs=model(inputs)
            loss=loss_function(outputs,labels)
            loss.backward()
            optimizer.step()#参数更新

            #print statistics
            running_loss+=loss.item()
            if step%500 == 499:
                with torch.no_grad():
                    outputs=model(test_image)#[batch,10]
                    predict_y=torch.max(outputs,dim=1)[1]#找到预测类别最可能值的下标
                    accuracy=torch.eq(predict_y,test_label).sum().item()/test_label.size(0)

                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0
    print('Finished Training')
 #保存模型
 save_path = './LeNet.pth'
 torch.save(model.state_dict(), save_path)

📌📌📌模型训练关键步骤:

  • optimizer.zero_grad()将每个参数的梯度值设为0,这样上一次的梯度记录被清空
  • outputs=model(inputs) loss=loss_function(outputs,labels) 前向传播计算损失
  • loss.backward()反向传播
  • optimizer.step()更新参数

核心是通过计算损失来更新参数

下半部分用于评估每一轮次的效果,它的核心思想是计算损失和准确率

3.预测

从网上随便下载一张飞机图片,用刚训练好的模型对它进行预测
在这里插入图片描述

新建一个predict.py文件
1)数据预处理

import torch
from model import LeNet
from torchvision import transforms
#定义数据预处理工具包
transform=transforms.Compose([
    transforms.Resize([32,32]),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])
from PIL import Image
#准备数据
img=Image.open('airplane.jpg')
img=transform(img)#[C,H,W]
img=torch.unsqueeze(img,dim=0)#增加一个维度 [N,C,H,W]

2)模型

#定义并加载模型
model=LeNet()
model.load_state_dict(torch.load('LeNet.pth'))

3)预测

classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
with torch.no_grad():
    output=model(img)
    predict=torch.max(output,dim=1)[1].numpy()
    print(predict) #[0]
print(classes[(predict)])#classes[0]='airplane'

输出>:airplane

  • 28
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值