《软件工程》-卷积神经网络

本文介绍了使用PyTorch实现卷积神经网络(CNN)在MNIST和CIFAR10数据集上的应用。通过对比全连接网络与CNN的效果,展示了CNN在图像识别中的优势。在MNIST数据集上,CNN相比全连接网络取得了更好的准确率。同时,通过打乱像素顺序的实验,进一步证明了CNN在局部性和平移不变性方面的特性。在CIFAR10数据集上,使用VGG16网络达到了84.92%的测试准确率。
摘要由CSDN通过智能技术生成

一.MNIST 数据集分类

深度卷积神经网络中,有如下特性

另外值得注意的是,DataLoader是一个比较重要的类,提供的常用操作有:batch_size(每个batch的大小), shuffle(是否进行随机打乱顺序的操作), num_workers(加载数据的时候使用几个子进程)

  • 很多层: compositionality
  • 卷积: locality + stationarity of images
  • 池化: Invariance of object class to translations
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torchvision import datasets, transforms
    import matplotlib.pyplot as plt
    import numpy
    
    # 一个函数,用来计算模型中有多少参数
    def get_n_params(model):
        np=0
        for p in list(model.parameters()):
            np += p.nelement()
        return np
    
    # 使用GPU训练,可以在菜单 "代码执行工具" -> "更改运行时类型" 里进行设置
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    1. 加载数据 (MNIST)

    PyTorch里包含了 MNIST, CIFAR10 等常用数据集,调用 torchvision.datasets 即可把这些数据由远程下载到本地,下面给出MNIST的使用方法:

    torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

  • root 为数据集下载到本地后的根目录,包括 training.pt 和 test.pt 文件
  • train,如果设置为True,从training.pt创建数据集,否则从test.pt创建。
  • download,如果设置为True, 从互联网下载数据并放到root文件夹下
  • transform, 一种函数或变换,输入PIL图片,返回变换之后的数据。
  • target_transform 一种函数或变换,输入目标,进行变换。
    input_size  = 28*28   # MNIST上的图像尺寸是 28x28
    output_size = 10      # 类别为 0 到 9 的数字,因此为十类
    
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=64, shuffle=True)
    
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transforms.Compose([
                 transforms.ToTensor(),
                 transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=1000, shuffle=True)

    运行结果为:

    Downloadinghttp://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
    Downloadinghttp://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to
    ./data/MNIST/raw/train-images-idx3-ubyte.gz Extracting
    ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

    Downloadinghttp://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
    Downloadinghttp://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to
    ./data/MNIST/raw/train-labels-idx1-ubyte.gz Extracting
    ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

    Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
    to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz Extracting
    ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

    Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
    to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz Extracting
    ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

    /usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:498:
    UserWarning: The given NumPy array is not writeable, and PyTorch does
    not support non-writeable tensors. This means you can write to the
    underlying (supposedly non-writeable) NumPy array using the tensor.
    You may want to copy the array to protect its data or make it
    writeable before converting it to a tensor. This type of warning will
    be suppressed for the rest of this program. (Triggered internally at
    /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.) return
    torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

      显示数据集中的部分图像

plt.figure(figsize=(8, 5))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    image, _ = train_loader.dataset.__getitem__(i)
    plt.imshow(image.squeeze().numpy(),'gray')
    plt.axis('off');

运行结果为:

2. 创建网络

定义网络时,需要继承nn.Module,并实现它的forward方法,把网络中具有可学习参数的层放在构造函数init中。

只要在nn.Module的子类中定义了forward函数,backward函数就会自动被实现(利用autograd)。

class FC2Layer(nn.Module):
    def __init__(self, input_size, n_hidden, output_size):
        # nn.Module子类的函数必须在构造函数中执行父类的构造函数
        # 下式等价于nn.Module.__init__(self)        
        super(FC2Layer, self).__init__()
        self.input_size = input_size
        # 这里直接用 Sequential 就定义了网络,注意要和下面 CNN 的代码区分开
        self.network = nn.Sequential(
            nn.Linear(input_size, n_hidden), 
            nn.ReLU(), 
            nn.Linear(n_hidden, n_hidden), 
            nn.ReLU(), 
            nn.Linear(n_hidden, output_size), 
            nn.LogSoftmax(dim=1)
        )
    def forward(self, x):
        # view一般出现在model类的forward函数中,用于改变输入或输出的形状
        # x.view(-1, self.input_size) 的意思是多维的数据展成二维
        # 代码指定二维数据的列数为 input_size=784,行数 -1 表示我们不想算,电脑会自己计算对应的数字
        # 在 DataLoader 部分,我们可以看到 batch_size 是64,所以得到 x 的行数是64
        # 大家可以加一行代码:print(x.cpu().numpy().shape)
        # 训练过程中,就会看到 (64, 784) 的输出,和我们的预期是一致的

        # forward 函数的作用是,指定网络的运行过程,这个全连接网络可能看不啥意义,
        # 下面的CNN网络可以看出 forward 的作用。
        x = x.view(-1, self.input_size)
        return self.network(x)
    


class CNN(nn.Module):
    def __init__(self, input_size, n_feature, output_size):
        # 执行父类的构造函数,所有的网络都要这么写
        super(CNN, self).__init__()
        # 下面是网络里典型结构的一些定义,一般就是卷积和全连接
        # 池化、ReLU一类的不用在这里定义
        self.n_feature = n_feature
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=n_feature, kernel_size=5)
        self.conv2 = nn.Conv2d(n_feature, n_feature, kernel_size=5)
        self.fc1 = nn.Linear(n_feature*4*4, 50)
        self.fc2 = nn.Linear(50, 10)    
    
    # 下面的 forward 函数,定义了网络的结构,按照一定顺序,把上面构建的一些结构组织起来
    # 意思就是,conv1, conv2 等等的,可以多次重用
    def forward(self, x, verbose=False):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.ma
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值