CNN-LeNet5_完整项目_CodingPark编程公园

CIFAR-10 数据集简介

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示。
下面这幅图就是列举了10各类,每一类展示了随机的10张图片:

在这里插入图片描述
与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:
• CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
• CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
• 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。 直接的线性模型如 Softmax 在 CIFAR-10 上表现得很差。

在这里插入图片描述
知识城邦🏯
#self.criteon = nn.MSELoss() # 处理逼近问题
#self.criteon = nn.CrossEntropyLoss() # 处理分类问题

nn后面的类都具有大写的元素

nn➕后面内容 是类,类就需要初始化
例如:self.fc_unit = nn.Sequential(
nn.Linear(3255, 32),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(32, 10)
)

之后
在forward时调用 logits = self.fc_unit(x)
⚠️ logits = self.fc_unit(x) 中 调用fc_unit实例时 ➕ (x)
你并没有调用fc_unit.forward(x)
这是因为 class Lenet5(nn.Module)中的nn.Module 有内部调用
logits = self.fc_unit(x) 相当于 logits = self.fc_unit.forward(x)

F 后面的方法都是小写的
就是函数,直接可以用

argmax: 要最大值的 索引
max: 要最大值

—————————utils.py—工具文件—————————

import torch
from matplotlib import pyplot as plt        # 绘图


# makecurve
# to show loss picture
def plot_curve(data):
    fig = plt.figure()  # 设置绘图区域的大小和像素
    plt.plot(range(len(data)),data,color = 'blue')  # 将实际值的折线设置为蓝色
    plt.legend(['value'],loc = 'upper right')   # 显示图例的位置,自适应方式
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


# draw image
# designed to show picture meterial
def plot_img(img, label, name):
    fig = plt.figure()  # plt.figure()用来画图,create a figure;自定义画布大小,表示figure 的大小为宽、长(单位为inch)
    for i in range(6):
        plt.subplot(2,3,i + 1)  # 表示整个figure分为2行3列
        plt.tight_layout()
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

#one-hot
def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

—————————lenet5.py—模型文件—————————
Lenet5网络结构 创建

import torch
from torch import nn
from torch.nn import functional as F


'''
step2. Net Creat   网络创建

'''

class Lenet5(nn.Module):


    def __init__(self):
        super(Lenet5, self).__init__()

        self.conv_unit = nn.Sequential(                                                Sequential不需要给每个层编号,同样也就没有赋值 ,因为总的来说它时一个整体;
            # x: [b, 3, 32, 32]=> [b, 6, size]
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=0),       # in_channels=3代表 3个通道in 其实就是RGB这三个
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),          # kernel_size=2 kernel一次看一个长宽各2的窗口

            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )

        # flatten打平         从16*5*5 -> 10

        # 全连接层
        self.fc_unit = nn.Sequential(
            nn.Linear(32*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )


        # [b, 3, 32, 32]                            测试1,看conv_unit后的out_channels数
        # tmp =  torch.randn(2, 3, 32, 32)
        # out = self.conv_unit(tmp)
        # # [b, 16, 5, 5]
        # print('conv out: ',out.shape)

        # use Cross Entropy Loss
        # self.criteon = nn.MSELoss()               # 处理逼近问题
        # self.criteon = nn.CrossEntropyLoss()      # 处理分类问题



    def forward(self, x):                    # 前向路径   #param(参数) x:[b, 3, 32, 32]
        batchsz = x.size(0)                  #param(参数) x:[b, 3, 32, 32] 其中的 b

        # [b, 3, 32, 32] => [b, 16, 5, 5]
        x = self.conv_unit(x)

        # flatten打平         从16*5*5 -> 10
        x = x.view(batchsz, 32*5*5)

        # [b, 16*5*5] => [b, 10]
        logits = self.fc_unit(x)        # logits: 网络最后一般送入softmax 那么 softmax前的变量 统称 logits

        return logits





# def main():
#     net = Lenet5()
#     tmp = torch.randn(2, 3, 32, 32)
#     out = net(tmp)
#     # [b, 16, 5, 5]
#     print('lenet5 out: ', out.shape)
#
#
# if __name__ == '__main__':
#     main()

—————————main.py—主文件—————————

import torch

from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn, optim
from lenet5 import Lenet5

import torchvision                              # vision 视觉
from matplotlib import pyplot as plt            # 绘图
from utils import plot_img, plot_curve, one_hot # 工具包




def main():
    '''
    step1. load dataset   加载数据集

    '''
    batchsz = 128
    # 训练集
    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]), download=True)

    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    # 测试集
    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]), download=True)

    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)



    # 尝试显示加载项
    x, label = next(iter(cifar_train))
    print(x.shape, label.shape, x.min(), x.max())
    plot_img(x, label, 'TEAM-AG_cifar10')

    '''
    step3.  Train  训练
    
    '''
    print('----------Train  训练----------')
    model = Lenet5()
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)     # model.parameters()优化网络模型中的参数
    train_loss = []        # train_loss记录
    print(model)           # 打印这个类的结构!

    for epoch in range(50):
        model.train()           # 网络模型train模式
        for batch_idx, (x, label) in enumerate(cifar_train):
            # x [b, 3, 32, 32]
            # label [b]

            # 走网络
            logits = model(x)

            # 走loss
            loss = criteon(logits, label)

            # 走素质三连
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

            if batch_idx % 10 == 0:
                print(epoch, batch_idx, loss.item())

    plot_curve(train_loss)  # 画 loss 曲线

    '''
    step4.  Test  测试

    '''
    print('----------Test  测试----------')
    model.eval()  # 网络模型eval模式
    #test不计算梯度 所以包在 with torch.no_grad(): 函数里
    with torch.no_grad():
        tottal_num = 0
        tottal_correct = 0
        for x, label in cifar_test:
            # 走网络
            logits = model(x)

            #取logits最大的元素, 作为pred(predict)
            pred = logits.argmax(dim=1)
            tottal_num = len(cifar_test.dataset)
            print('tottal_num : ', tottal_num)
            tottal_correct += torch.eq(pred, label).sum().float().item()
            print('tottal_correct : ', tottal_correct)

        acc = tottal_correct/tottal_num
        print('准确率: ',acc)

        x, label = next(iter(cifar_test))  # 取一个batch,查看预测结果
        out = model(x)
        pred = out.argmax(dim=1)  # 取得[b, 10]的10个值的最大值所在位置的索引
        plot_img(x, pred, 'TEAM-AG_cifar10')




if __name__ == '__main__':
    main()

在这里插入图片描述

评论将由博主筛选后显示,对所有人可见 | 还能输入1000个字符
©️2020 CSDN 皮肤主题: 鲸 设计师: meimeiellie 返回首页
实付0元
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值