PyTorch Learning (3): CIFAR10 Image Classification Case Study

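Contents

Dataset introduction

Loading and normalizing the CIFAR10 training and test datasets with torchvision

Appendix: formulas for the output size of convolution and pooling layers in a CNN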


 

Dataset introduction

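This example uses the CIFAR10 dataset. It has the following classes: "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck". Images in CIFAR-10 have size 3x32x32, i.e. 3-channel color images of 32x32 pixels.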

[Figure: cifar10]
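The dataset returns each label as an integer index from 0 to 9, in the same order as the class list above. A minimal plain-Python sketch (no PyTorch needed) of turning a batch of label indices back into names and of how many numbers one 3x32x32 image holds; the names `CLASSES`, `VALUES_PER_IMAGE` and `decode_labels` are illustrative, and the short class names match the `classes` tuple used in the script below.

```python
# Class names in label-index order (same order as the `classes` tuple in the script).
CLASSES = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# One image is 3 channels x 32 rows x 32 columns.
CHANNELS, HEIGHT, WIDTH = 3, 32, 32
VALUES_PER_IMAGE = CHANNELS * HEIGHT * WIDTH  # 3072 numbers per image


def decode_labels(label_indices):
    """Map integer labels (0-9) to their class names."""
    return [CLASSES[i] for i in label_indices]


if __name__ == '__main__':
    # In the run shown in the output below, the first training batch had labels [9, 2, 8, 7].
    print(decode_labels([9, 2, 8, 7]))  # ['truck', 'bird', 'ship', 'horse']
    print(VALUES_PER_IMAGE)             # 3072
```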

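Basic steps:

  1. Load and normalize the CIFAR10 training and test datasets with torchvision
  2. Define a convolutional neural network
  3. Define a loss function
  4. Train the network on the training data
  5. Test the network on the test data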
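Step 1 normalizes with `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, which computes `(x - mean) / std` for each channel. `ToTensor()` has already scaled 8-bit pixel values into [0, 1], so mean = std = 0.5 maps them onto [-1, 1], and the `imshow` helper undoes it with `img / 2 + 0.5`. A plain-Python sketch of that arithmetic for a single value (the function names are illustrative, not torchvision APIs):

```python
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize in this tutorial


def to_tensor_value(pixel):
    """Scale an 8-bit pixel (0-255) to [0, 1], as ToTensor() does."""
    return pixel / 255.0


def normalize(x, mean=MEAN, std=STD):
    """Per-channel normalization: (x - mean) / std."""
    return (x - mean) / std


def unnormalize(y, mean=MEAN, std=STD):
    """Inverse of normalize; with mean = std = 0.5 this is y / 2 + 0.5."""
    return y * std + mean


if __name__ == '__main__':
    for pixel in (0, 217, 255):
        y = normalize(to_tensor_value(pixel))
        print(pixel, round(y, 4), round(unnormalize(y) * 255))
    # 0 -1.0 0
    # 217 0.702 217
    # 255 1.0 255
```

For instance, the first value 0.7020 in the printed batch further down corresponds to a raw pixel value of 217.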

Loading and normalizing the CIFAR10 training and test datasets with torchvision

# -*- coding: utf-8 -*-
# @Time    : 2021/1/13 11:56
# @Author  : wwq_biubiu!!
# @FileName: chapter2_5_Cifar10.py
# @Software: PyCharm
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

import matplotlib.pyplot as plt
import numpy as np

# function to show an image

def imshow(img):
    img = img / 2 + 0.5  # unnormalize: map [-1, 1] back to [0, 1]
    npimg = img.numpy()
    # np.transpose reorders the axes: the tensor is (channels, height, width),
    # but plt.imshow expects (height, width, channels); since the two layouts
    # differ, we call np.transpose once before plotting
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

    plt.show()

# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)  # use built-in next(); newer DataLoader iterators have no .next() method
print(images, labels)
print(images.shape, labels.shape)
# show images (uncomment to display the batch as a grid)
#imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

# Define the convolutional neural network
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)      # 3x32x32 -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)       # halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)     # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)         # one score per class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)           # flatten the 16x5x5 feature maps
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
net.to(device)
import torch.optim as optim


# Start training
losses = []
acces = []
eval_losses = []
eval_acces = []

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
for epoch in range(2):
    run_loss = 0.0      # sums over the current 2000-batch logging window
    train_acc = 0
    epoch_loss = 0.0    # sums over the whole epoch, for the per-epoch history
    epoch_acc = 0
    for i, data in enumerate(trainloader, 0):
        # move the batch to the same device as the model
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # torch.max(x, 1) returns (max values, indices of the max values);
        # the indices are the predicted class labels
        _, predicted = torch.max(outputs, 1)

        total = labels.size(0)  # batch size
        num_correct = (predicted == labels).sum().item()

        acc = num_correct / total
        train_acc += acc
        epoch_acc += acc
        run_loss += loss.item()
        epoch_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, run_loss / 2000))
            print('[%d, %5d] acc: %.3f' %
                  (epoch + 1, i + 1, train_acc / 2000))
            run_loss = 0.0
            train_acc = 0
    # average over all mini-batches of the epoch (the last window holds only 500 of them)
    acces.append(epoch_acc / len(trainloader))
    losses.append(epoch_loss / len(trainloader))
print('Finished Training')

# Save the model
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

# Load the model into a fresh network on the CPU
net = Net()
net.load_state_dict(torch.load(PATH, map_location='cpu'))
outputs = net(images)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))
# Evaluate: overall accuracy on the 10000 test images
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
# Evaluate: accuracy for each class
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)
        for i in range(labels.size(0)):  # batch size (4 here)
            label = labels[i].item()
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
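The per-class evaluation above counts, for every class, how many test images of that class the network predicted correctly. The same bookkeeping in plain Python, on lists of true and predicted label indices, is sketched below; `per_class_accuracy` is an illustrative name, not part of PyTorch.

```python
from collections import Counter


def per_class_accuracy(labels, predictions, num_classes=10):
    """Return the fraction of correct predictions for each class.

    labels and predictions are equal-length sequences of class indices.
    Classes that never appear in labels get an accuracy of 0.0.
    """
    total = Counter(labels)
    correct = Counter(t for t, p in zip(labels, predictions) if t == p)
    return [correct[c] / total[c] if total[c] else 0.0
            for c in range(num_classes)]


if __name__ == '__main__':
    labels = [3, 8, 8, 0, 6, 6, 1, 6]
    preds = [3, 8, 0, 0, 6, 2, 1, 4]
    acc = per_class_accuracy(labels, preds)
    print(acc[6])   # 1 of 3 frog images correct: 0.3333333333333333
    overall = sum(t == p for t, p in zip(labels, preds)) / len(labels)
    print(overall)  # 0.625
```

Absent classes report 0.0 here instead of dividing by zero; with the CIFAR10 test set, which has 1000 images per class, that case does not arise.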

Output:

 

D:\anaconda3\envs\GNN\python.exe E:/Work2020/2020.12.02-deeplearning/Numpy/chapter2_5_Cifar10.py
Files already downloaded and verified
Files already downloaded and verified
tensor([[[[ 0.7020,  0.7020,  0.6941,  ...,  0.6314,  0.6235,  0.6235],
          [ 0.7882,  0.8039,  0.7725,  ...,  0.6314,  0.6314,  0.6314],
          [ 0.8510,  0.8431,  0.7961,  ...,  0.6235,  0.6235,  0.6314],
          ...,
          [ 0.2784,  0.1294, -0.2471,  ..., -0.5294, -0.4510, -0.7098],
          [ 0.2549,  0.2784,  0.0667,  ..., -0.4980, -0.3647, -0.4196],
          [ 0.1686,  0.2235,  0.2627,  ..., -0.5608, -0.2471, -0.3882]],

         [[ 0.7490,  0.7333,  0.7255,  ...,  0.7412,  0.7412,  0.7412],
          [ 0.7804,  0.7961,  0.7647,  ...,  0.7412,  0.7490,  0.7490],
          [ 0.8353,  0.8275,  0.7804,  ...,  0.7412,  0.7490,  0.7490],
          ...,
          [ 0.2078,  0.0510, -0.3255,  ..., -0.5843, -0.5216, -0.7882],
          [ 0.1922,  0.2000, -0.0196,  ..., -0.5843, -0.4431, -0.4980],
          [ 0.0980,  0.1373,  0.1686,  ..., -0.6392, -0.3255, -0.4667]],

         [[ 0.8824,  0.9059,  0.9216,  ...,  0.8824,  0.9137,  0.9294],
          [ 0.8980,  0.9373,  0.9294,  ...,  0.8980,  0.9137,  0.9373],
          [ 0.9373,  0.9608,  0.9373,  ...,  0.8980,  0.9137,  0.9373],
          ...,
          [ 0.5608,  0.3333, -0.1373,  ..., -0.5059, -0.4510, -0.7255],
          [ 0.5765,  0.5216,  0.1686,  ..., -0.4902, -0.3569, -0.4039],
          [ 0.4824,  0.4745,  0.3961,  ..., -0.5529, -0.2392, -0.3804]]],


        [[[ 0.2471,  0.2314,  0.2784,  ...,  0.7569,  0.7412,  0.7098],
          [ 0.2314,  0.2078,  0.2549,  ...,  0.7333,  0.7176,  0.6863],
          [ 0.2549,  0.2392,  0.2941,  ...,  0.7412,  0.7176,  0.6863],
          ...,
          [ 0.7490,  0.7333,  0.7412,  ..., -0.0353, -0.0824, -0.1922],
          [ 0.8039,  0.7961,  0.8353,  ...,  0.3647,  0.0902, -0.2392],
          [ 0.8431,  0.8510,  0.8510,  ...,  0.2549, -0.0275, -0.3569]],

         [[ 0.2392,  0.2235,  0.2706,  ...,  0.7569,  0.7412,  0.7098],
          [ 0.2235,  0.2000,  0.2471,  ...,  0.7333,  0.7176,  0.6863],
          [ 0.2471,  0.2314,  0.2863,  ...,  0.7412,  0.7176,  0.6863],
          ...,
          [ 0.7412,  0.7333,  0.7412,  ..., -0.1843, -0.2863, -0.4039],
          [ 0.8118,  0.7961,  0.8353,  ...,  0.2314, -0.0745, -0.3961],
          [ 0.8667,  0.8431,  0.8196,  ...,  0.1059, -0.2000, -0.5059]],

         [[ 0.2078,  0.1922,  0.2392,  ...,  0.7412,  0.7255,  0.6941],
          [ 0.1922,  0.1686,  0.2157,  ...,  0.7176,  0.7020,  0.6706],
          [ 0.2157,  0.2000,  0.2549,  ...,  0.7255,  0.7020,  0.6706],
          ...,
          [ 0.7333,  0.7333,  0.7569,  ..., -0.2941, -0.3961, -0.5137],
          [ 0.8118,  0.7961,  0.8275,  ...,  0.1294, -0.1529, -0.4667],
          [ 0.8824,  0.8353,  0.7961,  ..., -0.0196, -0.2941, -0.5765]]],


        [[[ 0.7647,  0.7412,  0.7412,  ...,  0.8431,  0.8353,  0.8196],
          [ 0.7882,  0.7647,  0.7725,  ...,  0.8667,  0.8588,  0.8431],
          [ 0.8039,  0.7882,  0.7961,  ...,  0.8588,  0.8588,  0.8431],
          ...,
          [ 0.1373,  0.4667,  0.7255,  ...,  0.1059,  0.0745,  0.0431],
          [ 0.1686,  0.4118,  0.6314,  ...,  0.0824,  0.0745,  0.0196],
          [ 0.0902,  0.2627,  0.4431,  ..., -0.0667, -0.0824, -0.0902]],

         [[ 0.8196,  0.7961,  0.7961,  ...,  0.8431,  0.8353,  0.8196],
          [ 0.8353,  0.8118,  0.8275,  ...,  0.8667,  0.8588,  0.8431],
          [ 0.8353,  0.8118,  0.8196,  ...,  0.8588,  0.8588,  0.8431],
          ...,
          [ 0.1922,  0.4196,  0.6157,  ...,  0.1529,  0.1216,  0.0902],
          [ 0.2627,  0.4039,  0.5529,  ...,  0.1294,  0.1137,  0.0667],
          [ 0.1922,  0.3176,  0.4431,  ..., -0.0196, -0.0510, -0.0588]],

         [[ 0.8824,  0.8667,  0.8667,  ...,  0.9059,  0.8980,  0.8824],
          [ 0.8902,  0.8667,  0.8824,  ...,  0.9294,  0.9216,  0.9059],
          [ 0.8824,  0.8588,  0.8667,  ...,  0.9216,  0.9216,  0.9059],
          ...,
          [ 0.3647,  0.5294,  0.6157,  ...,  0.3490,  0.3255,  0.2941],
          [ 0.4275,  0.5216,  0.5765,  ...,  0.3255,  0.3176,  0.2706],
          [ 0.3569,  0.4275,  0.4667,  ...,  0.1529,  0.1373,  0.1294]]],


        [[[-0.5137, -0.2471, -0.1059,  ...,  0.3961, -0.3569, -0.5059],
          [-0.1843, -0.4431, -0.6706,  ...,  0.0039, -0.5059, -0.5922],
          [-0.1373, -0.4353, -0.6078,  ..., -0.3961, -0.4196, -0.5373],
          ...,
          [ 0.3882,  0.3725,  0.3725,  ...,  0.3333,  0.3020,  0.2471],
          [ 0.3647,  0.3882,  0.4196,  ...,  0.3725,  0.2627,  0.1608],
          [ 0.4588,  0.4824,  0.4824,  ...,  0.2235,  0.1451,  0.0980]],

         [[-0.6314, -0.3569, -0.1294,  ...,  0.1843, -0.5922, -0.6863],
          [-0.3412, -0.5843, -0.7176,  ..., -0.1843, -0.7961, -0.8431],
          [-0.3255, -0.5922, -0.6706,  ..., -0.5843, -0.7490, -0.8196],
          ...,
          [ 0.1686,  0.1529,  0.1529,  ...,  0.0275, -0.0275, -0.0980],
          [ 0.1451,  0.1686,  0.2000,  ...,  0.0588, -0.0667, -0.1922],
          [ 0.2392,  0.2627,  0.2627,  ..., -0.0824, -0.1765, -0.2471]],

         [[-0.4824, -0.1137,  0.0980,  ...,  0.1765, -0.4980, -0.5922],
          [-0.2235, -0.3725, -0.4980,  ..., -0.2000, -0.7412, -0.7647],
          [-0.2549, -0.4275, -0.4745,  ..., -0.5686, -0.7020, -0.7412],
          ...,
          [ 0.0745,  0.0588,  0.0588,  ..., -0.1059, -0.1529, -0.2235],
          [ 0.0510,  0.0745,  0.1059,  ..., -0.0745, -0.1922, -0.3098],
          [ 0.1451,  0.1686,  0.1686,  ..., -0.2157, -0.3020, -0.3647]]]]) tensor([9, 2, 8, 7])
torch.Size([4, 3, 32, 32]) torch.Size([4])
truck  bird  ship horse
cuda:0
[1,  2000] loss: 2.303
[1,  2000] acc: 0.104
[1,  4000] loss: 2.302
[1,  4000] acc: 0.103
[1,  6000] loss: 2.298
[1,  6000] acc: 0.127
[1,  8000] loss: 2.289
[1,  8000] acc: 0.147
[1, 10000] loss: 2.254
[1, 10000] acc: 0.182
[1, 12000] loss: 2.178
[1, 12000] acc: 0.215
[2,  2000] loss: 2.070
[2,  2000] acc: 0.241
[2,  4000] loss: 2.000
[2,  4000] acc: 0.265
[2,  6000] loss: 1.951
[2,  6000] acc: 0.282
[2,  8000] loss: 1.896
[2,  8000] acc: 0.311
[2, 10000] loss: 1.862
[2, 10000] acc: 0.323
[2, 12000] loss: 1.825
[2, 12000] acc: 0.336
Finished Training
Predicted:  truck   car plane horse
Accuracy of the network on the 10000 test images: 36 %
Accuracy of plane : 56 %
Accuracy of   car : 51 %
Accuracy of  bird : 12 %
Accuracy of   cat : 17 %
Accuracy of  deer : 32 %
Accuracy of   dog : 36 %
Accuracy of  frog : 51 %
Accuracy of horse : 41 %
Accuracy of  ship : 13 %
Accuracy of truck : 51 %

Process finished with exit code 0
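The first logged loss, 2.303, is what cross-entropy gives when the network's 10 outputs are nearly equal: softmax assigns about 1/10 to every class, so the loss is -ln(1/10) = ln 10 ≈ 2.3026, and the first logged accuracy of about 0.10 is chance level. A small plain-Python check of that arithmetic; the helper computes log-softmax followed by negative log-likelihood for one sample, which is what `nn.CrossEntropyLoss` computes per sample, but it is a sketch and not the PyTorch implementation.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target]).

    Uses the log-sum-exp trick (subtracting the max logit) for numerical stability.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


if __name__ == '__main__':
    uniform = [0.0] * 10                        # equal scores for all 10 classes
    print(round(cross_entropy(uniform, 3), 4))  # 2.3026
    print(round(math.log(10), 4))               # 2.3026
```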
 

Appendix: formulas for the output size of convolution and pooling layers in a CNN

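For a convolution layer whose input has height h and width w, with a square kernel of side kernel, padding pad and stride Stride (height and width may use different values; just compute each separately), the output dimensions are:

$$h_{out} = \left\lfloor \frac{h - kernel + 2 \cdot pad}{Stride} \right\rfloor + 1, \qquad w_{out} = \left\lfloor \frac{w - kernel + 2 \cdot pad}{Stride} \right\rfloor + 1$$

where the brackets ⌊ ⌋ denote rounding down (floor). The pooling layers use the same formula, with the pooling window in place of the kernel.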
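The formula as a small Python helper, where integer division `//` performs the floor for non-negative numerators. The name `conv_output_size` is illustrative; it assumes dilation 1 (the PyTorch default) and, for pooling, the default `ceil_mode=False` of `nn.MaxPool2d`.

```python
def conv_output_size(size, kernel, pad=0, stride=1):
    """Output height (or width) of a conv/pool layer:
    floor((size - kernel + 2 * pad) / stride) + 1, assuming dilation 1."""
    return (size - kernel + 2 * pad) // stride + 1


if __name__ == '__main__':
    print(conv_output_size(32, 5))           # 5x5 conv, stride 1: 28
    print(conv_output_size(32, 5, pad=2))    # padding 2 keeps the size: 32
    print(conv_output_size(7, 2, stride=2))  # floor(5 / 2) + 1 = 3
```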

For this model, the dimensions are computed as follows:

conv1: (32 - 5 + 2 * 0) / 1 + 1 = 28

pool1: (28 - 2) / 2 + 1 = 14

conv2: (14 - 5 + 2 * 0) / 1 + 1 = 10

pool2: (10 - 2) / 2 + 1 = 5
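Chaining these steps gives the input size of `fc1`: after `pool2` the feature map is 16 channels of 5x5, so `x.view(-1, 16 * 5 * 5)` flattens it to 400 values. A plain-Python sketch that walks the layers of this `Net`; the `LAYERS` table is a hand-written description of the model, not something read from PyTorch.

```python
# (name, out_channels or None for pooling, kernel, stride, pad) for each layer of Net.
LAYERS = [
    ('conv1', 6, 5, 1, 0),
    ('pool1', None, 2, 2, 0),
    ('conv2', 16, 5, 1, 0),
    ('pool2', None, 2, 2, 0),
]


def trace_shapes(channels=3, size=32, layers=LAYERS):
    """Return (name, channels, size) after each layer, assuming square inputs."""
    shapes = []
    for name, out_channels, kernel, stride, pad in layers:
        size = (size - kernel + 2 * pad) // stride + 1
        if out_channels is not None:  # pooling keeps the channel count
            channels = out_channels
        shapes.append((name, channels, size))
    return shapes


if __name__ == '__main__':
    shapes = trace_shapes()
    for name, c, s in shapes:
        print(f'{name}: {c} x {s} x {s}')
    _, c, s = shapes[-1]
    print('flattened features for fc1:', c * s * s)  # 400
```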
