【邱锡鹏】nndl-chap5: Digit Recognition (PyTorch)

This post walks through handwritten digit recognition with PyTorch: importing the required packages, loading the MNIST data, building a CNN model, and then training and testing the network in detail. The training and testing stages show the dataset sizes and some of the prediction results.

1. Import the packages

import os
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms

learning_rate = 1e-4
keep_prob_rate = 0.7  # TF-style keep probability; nn.Dropout takes the drop probability, i.e. 1 - keep_prob_rate
max_epoch = 3
BATCH_SIZE = 50

DOWNLOAD_MNIST = False
if not os.path.exists('MNIST') or not os.listdir('MNIST'):
    # the MNIST directory is missing or empty, so download the dataset
    DOWNLOAD_MNIST = True

2. Load the data

train_data = torchvision.datasets.MNIST(root='./', train=True, download=DOWNLOAD_MNIST,
                                        transform=torchvision.transforms.Compose([
                                            transforms.ToTensor(),
                                        ]))
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root='./', train=False,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                       ]))
test_loader = Data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)

# Keep the first 500 test images as a fixed evaluation set.
# test_data.data / test_data.targets replace the deprecated test_data / test_labels
# attributes, and plain tensors replace the old Variable(..., volatile=True) wrapper.
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:500] / 255.
test_y = test_data.targets[:500].numpy()
print(test_x.shape)
print(test_y.shape)

# torch.Size([500, 1, 28, 28])
# (500,)
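
As a quick sanity check, one batch can be drawn from train_loader to confirm the tensor layout that the CNN below expects. This is only a minimal sketch; the printed sizes assume the BATCH_SIZE of 50 set above:

# Draw a single batch from the training loader and inspect its shape and value range.
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([50, 1, 28, 28])
print(labels.shape)   # torch.Size([50])
print(images.min().item(), images.max().item())  # ToTensor scales pixel values into [0, 1]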

3. The CNN model

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                # 7x7 kernel, 1 input channel, 32 output channels, stride 1, no padding:
                # 28x28 -> 22x22, then 2x2 max pooling -> 11x11
                in_channels=1,
                out_channels=32,
                kernel_size=7,
                stride=1,
                padding=0,
            ),
            nn.ReLU(),        # activation function
            nn.MaxPool2d(2),  # pooling operation
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                # 5x5 kernel, 32 input channels, 64 output channels, stride 1, no padding:
                # 11x11 -> 7x7; the MaxPool2d(1) below is effectively a no-op, so 7x7 is kept
                in_channels=32,
                out_channels=64,
                kernel_size=5,
                stride=1,
                padding=0,
            ),
            nn.ReLU(),
            nn.MaxPool2d(1),
        )
        self.out1 = nn.Linear(7 * 7 * 64, 1024, bias=True)  # first fully connected layer

        # nn.Dropout expects the probability of *dropping* a unit,
        # so the TF-style keep probability is converted here.
        self.dropout = nn.Dropout(1 - keep_prob_rate)
        self.out2 = nn.Linear(1024, 10, bias=True)


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(-1, 64 * 7 * 7)  # flatten the output of conv2 to (batch_size, 64 * 7 * 7)
        out1 = self.out1(x)
        out1 = F.relu(out1)
        out1 = self.dropout(out1)
        out2 = self.out2(out1)
        # nn.CrossEntropyLoss applies log-softmax internally, so returning the raw
        # logits out2 would be the more conventional choice; the explicit softmax
        # still trains, just with a weaker gradient signal.
        output = F.softmax(out2, dim=1)
        return output
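
To double-check the 7*7*64 flatten size used by the first fully connected layer, a dummy input can be pushed through the two convolution blocks. This is only a small verification sketch; the names net, dummy, h1 and h2 are illustrative and not part of the original code:

# Trace the feature-map sizes for a single 28x28 input.
net = CNN()
dummy = torch.zeros(1, 1, 28, 28)
h1 = net.conv1(dummy)    # 28x28 -> 22x22 after the 7x7 conv, -> 11x11 after MaxPool2d(2)
h2 = net.conv2(h1)       # 11x11 -> 7x7 after the 5x5 conv; MaxPool2d(1) leaves it at 7x7
print(h1.shape)          # torch.Size([1, 32, 11, 11])
print(h2.shape)          # torch.Size([1, 64, 7, 7])
print(net(dummy).shape)  # torch.Size([1, 10])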

4. Training and testing

def test(cnn):
    global prediction
    cnn.eval()             # disable dropout for evaluation
    with torch.no_grad():  # no gradients are needed at test time
        y_pre = cnn(test_x)
    _, pre_index = torch.max(y_pre, 1)
    pre_index = pre_index.view(-1)
    prediction = pre_index.numpy()
    correct = np.sum(prediction == test_y)
    cnn.train()            # restore training mode
    return correct / 500.0


def train(cnn):
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss()
    for epoch in range(max_epoch):
        for step, (x, y) in enumerate(train_loader):
            output = cnn(x)
            loss = loss_func(output, y)  # scalar loss for this batch

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step != 0 and step % 20 == 0:
                print("=" * 10, step, "=" * 5, "=" * 5, "test accuracy is ", test(cnn), "=" * 10)

4.1 Training

cnn = CNN()
train(cnn)
========== 20 ===== ===== test accuracy is  0.25 ==========
========== 40 ===== ===== test accuracy is  0.458 ==========
========== 60 ===== ===== test accuracy is  0.572 ==========
========== 80 ===== ===== test accuracy is  0.624 ==========
========== 100 ===== ===== test accuracy is  0.638 ==========
========== 120 ===== ===== test accuracy is  0.718 ==========
========== 140 ===== ===== test accuracy is  0.744 ==========
========== 160 ===== ===== test accuracy is  0.796 ==========
========== 180 ===== ===== test accuracy is  0.806 ==========
========== 200 ===== ===== test accuracy is  0.808 ==========
========== 220 ===== ===== test accuracy is  0.844 ==========
========== 240 ===== ===== test accuracy is  0.84 ==========
========== 260 ===== ===== test accuracy is  0.864 ==========
========== 280 ===== ===== test accuracy is  0.86 ==========
========== 300 =====