PyTorch实战之MNIST手写数字识别

转载请注明作者和出处: http://blog.csdn.net/john_bh/

本文将使用PyTorch构建一个简单的卷积神经网络(CNN),并在MNIST数据集上训练它来识别手写数字。

1. MNIST Data

MNIST包含70,000张手写数字图像:60,000张用于训练,10,000张用于测试。图像均为28×28像素的灰度图,且数字已居中,因此几乎不需要额外预处理,训练也更快。MNIST Data Download
在这里插入图片描述

2. Load Data

新建 DealDataSet.py 文件(第4节的训练代码通过 `from DealDataSet import load_data_torch` 导入它,文件名需保持一致),用于加载MNIST数据,提供两种方法:

  1. 先下载MNIST数据集到本地,自己编写处理代码;
  2. 使用 torchvision.datasets 处理MNIST。

代码实现如下:

# -*- coding:utf-8 -*-
import os
import gzip
import numpy as np
from torchvision import transforms
from torchvision.datasets import mnist
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

def load_data_minist(data_folder):
    """Read the four gzip-compressed MNIST IDX files stored in *data_folder*.

    Label files start with an 8-byte header and image files with a 16-byte
    header; both headers are skipped and the remaining bytes are read as
    ``uint8``.

    :param data_folder: directory containing the original MNIST ``*.gz`` files
    :return: ``((x_train, y_train), (x_test, y_test))`` where the image arrays
             are reshaped to ``(N, 28, 28)`` and the label arrays are 1-D
    """
    def read_split(label_file, image_file):
        # Labels are read first: their count determines the image reshape.
        with gzip.open(os.path.join(data_folder, label_file), 'rb') as fh:
            labels = np.frombuffer(fh.read(), np.uint8, offset=8)
        with gzip.open(os.path.join(data_folder, image_file), 'rb') as fh:
            images = np.frombuffer(fh.read(), np.uint8, offset=16)
        return images.reshape(len(labels), 28, 28), labels

    train = read_split('train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
    test = read_split('t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
    return train, test

def load_data(data_folder, data_name, label_name):
    """Read one MNIST split (images + labels) from gzip-compressed IDX files.

    :param data_folder: directory containing the files
    :param data_name: file name of the gzip-compressed image file
    :param label_name: file name of the gzip-compressed label file
    :return: ``(x, y)`` where ``x`` is a ``uint8`` array reshaped to
             ``(N, 28, 28)`` and ``y`` is a 1-D ``uint8`` label array
    """
    label_path = os.path.join(data_folder, label_name)
    image_path = os.path.join(data_folder, data_name)

    # 'rb' = read raw bytes; the first 8 bytes are the IDX label header.
    with gzip.open(label_path, 'rb') as stream:
        labels = np.frombuffer(stream.read(), dtype=np.uint8, offset=8)

    # Image files carry a 16-byte header; the label count fixes the reshape.
    with gzip.open(image_path, 'rb') as stream:
        raw = stream.read()
    images = np.frombuffer(raw, dtype=np.uint8, offset=16).reshape(len(labels), 28, 28)

    return (images, labels)

class DealDataSet():
    """Map-style dataset over MNIST arrays read from local gzip IDX files.

    It implements ``__getitem__`` and ``__len__``, which is all
    ``torch.utils.data.DataLoader`` needs to index it.
    """

    def __init__(self, folder, data_name, label_name, transform=None):
        """
        :param folder: directory containing the gzip files
        :param data_name: file name of the image file
        :param label_name: file name of the label file
        :param transform: optional callable applied to every image (for example
                          a ``transforms.Compose``); ``None`` (the default)
                          returns the raw ``uint8`` array unchanged
        """
        # If the data had been saved with torch.save, torch.load() could be
        # used here instead and would yield torch.Tensor objects directly.
        X_set, Y_set = load_data(folder, data_name, label_name)
        self.X_set = X_set
        self.Y_set = Y_set
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)``; the label is converted to a Python ``int``."""
        img = self.X_set[index]
        target = int(self.Y_set[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Number of samples in the split."""
        return len(self.X_set)

def load_data_torch(root='./MNIST_data', download=True):
    """Load the MNIST train and test splits through ``torchvision``.

    Every image goes through ``ToTensor`` (PIL image -> float tensor in
    ``[0, 1]``) followed by ``Normalize(mean=0.5, std=0.5)``, which maps the
    values to ``[-1, 1]``.

    :param root: directory where the dataset is stored / downloaded to
                 (defaults to the previously hard-coded ``'./MNIST_data'``)
    :param download: fetch the files into *root* when they are missing
    :return: ``(train_set, test_set)``
    """
    # Compose chains the preprocessing steps into a single callable.
    data_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    train_set = mnist.MNIST(root, train=True, transform=data_tf, download=download)
    test_set = mnist.MNIST(root, train=False, transform=data_tf, download=download)

    return (train_set, test_set)

def show_example_img():
    """Plot the first six images of the first MNIST test batch with their labels."""
    _, test_set = load_data_torch()
    test_data = DataLoader(test_set, batch_size=64, shuffle=False)
    # The test loader is not shuffled, so this is always the same first batch.
    example_data, example_targets = next(iter(test_data))

    plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        # Batches are (batch, channel, H, W); channel 0 is the grayscale image.
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title("Ground Truth: {}".format(example_targets[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists.
    plt.tight_layout()
    plt.show()

def show_example_targets():
    """Print the labels and the tensor shape of the first MNIST test batch."""
    _, test_set = load_data_torch()
    # Only the unshuffled test loader is needed to inspect a fixed batch.
    test_data = DataLoader(test_set, batch_size=64, shuffle=False)
    example_data, example_targets = next(iter(test_data))
    print(example_targets)
    print(example_data.shape)


if __name__ == "__main__":
    # Quick visual and textual sanity checks of the data pipeline.
    for demo in (show_example_img, show_example_targets):
        demo()

show_example_img()结果:
在这里插入图片描述
show_example_targets()结果:

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2, 4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3])
torch.Size([64, 1, 28, 28])

3. 定义CNN

新建 model.py文件,编写CNN,具体代码如下:

# -*- coding:utf-8 -*-
from torch import nn

# define Net
class Net(nn.Module):
    """Small CNN for 28x28 digit images: four 3x3 conv blocks plus an MLP head.

    Shape trace for an ``(N, in_channels, 28, 28)`` input (convs use no padding):
    layer1 -> (N, 16, 26, 26); layer2 -> (N, 32, 24, 24) -> pool -> (N, 32, 12, 12);
    layer3 -> (N, 64, 10, 10); layer4 -> (N, 128, 8, 8) -> pool -> (N, 128, 4, 4);
    flatten -> (N, 2048); fc -> (N, num_classes) raw logits (no softmax).

    :param num_classes: number of output logits (10 for MNIST digits)
    :param in_channels: number of input image channels (1 for grayscale)
    """

    def __init__(self, num_classes=10, in_channels=1):
        super(Net, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3),  # 16, 26, 26
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True))

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3),  # 32, 24, 24
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2))  # 32, 12, 12   (24-2)/2 + 1

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),  # 64, 10, 10
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))

        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),  # 128, 8, 8
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2))  # 128, 4, 4

        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes))

    def forward(self, x):
        """Return un-normalised class scores of shape ``(N, num_classes)``."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.view(x.size(0), -1)  # flatten to (N, 128 * 4 * 4)
        return self.fc(x)

4. Train & Test

这里写了一个完整的训练代码,包括:参数的设定,加载MNIST数据,使用网络,定义损失函数和优化器,训练网络,保存模型,保存训练过程中的loss和acc并做可视化,设置验证集(测试集)。具体代码如下:

# -*- coding:utf-8 -*-
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch.autograd import Variable
from torch import optim
from DealDataSet import load_data_torch

from model import Net
import os
import matplotlib.pyplot as plt
import numpy as np

# Training parameters
max_epoch = 3
learning_rate = 1e-1
momentum = 0.9
batch_size = 64
display_step = 100  # print a progress line every `display_step` mini-batches
mode_dir = "./models/model1"  # checkpoint directory (model.pth / optimizer.pth)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Network parameters (informational only: none of these are passed to Net())
num_input = 784  # MNIST data input (img shape: 28*28)
num_classes = 10  # MNIST total classes (0-9 digits)
# NOTE: unused -- Net contains no Dropout layer. Beware that PyTorch's
# nn.Dropout(p) interprets p as the probability of *dropping* a unit, not of
# keeping it, so this value would need converting (p = 1 - 0.75) if wired in.
dropout = 0.75

# Load MNIST and wrap both splits in mini-batch loaders; only the training
# split is shuffled.
train_set, test_set = load_data_torch()
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_data = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# Build the model and move its parameters to the selected device
# (Module.to returns the module itself).
net = Net().to(device)

# Cross-entropy on Net's raw logits; SGD with momentum over all parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=learning_rate, momentum=momentum)

def draw_train_process(title, iters, costs, accs, label_cost, lable_acc):
    """Plot a loss curve and an accuracy curve on one shared axis and show it.

    :param title: figure title
    :param iters: x values (iteration / batch indices)
    :param costs: loss values, drawn in red
    :param accs: accuracy values, drawn in green
    :param label_cost: legend label for the loss curve
    :param lable_acc: legend label for the accuracy curve (misspelt name kept
                      for backward compatibility with keyword callers)
    """
    plt.title(title, fontsize=24)
    plt.xlabel("iter", fontsize=20)
    # Both curves share this axis. The previous literal "acc(\%)" contained an
    # invalid escape sequence and, with matplotlib's default (non-usetex) text
    # rendering, displayed the backslash verbatim.
    plt.ylabel("loss / acc(%)", fontsize=20)
    plt.plot(iters, costs, color='red', label=label_cost)
    plt.plot(iters, accs, color='green', label=lable_acc)
    plt.legend()
    plt.grid()
    plt.show()

def showTestResult(test_data):
    """Show the first six images of the first batch of *test_data* with the
    class predicted by the module-level ``net`` as each subplot title.

    The network is run as-is; the caller is expected to have put it in eval
    mode (``test()`` does so before calling this).
    """
    example_data, _ = next(iter(test_data))
    with torch.no_grad():
        # The model lives on `device`, the loader yields CPU tensors.
        output = net(example_data.to(device))
    # Index of the highest logit per sample, computed once for the batch.
    preds = output.argmax(dim=1).cpu()

    plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title("Prediction: {}".format(preds[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists.
    plt.tight_layout()
    plt.show()

def train():
    """Train the global ``net`` for ``max_epoch`` epochs, resuming if possible.

    If ``mode_dir/model.pth`` exists, the model (and, when present, the
    optimizer) state is restored first. After every epoch both states are
    saved to ``mode_dir`` and ``test()`` is run on the test set. Finally the
    per-batch training loss / accuracy (%) of the whole run are plotted.
    """
    global net
    model_path = os.path.join(mode_dir, 'model.pth')
    optim_path = os.path.join(mode_dir, 'optimizer.pth')

    # Resume only when a checkpoint file really exists: the directory alone may
    # be left over from a run that stopped before its first save.
    if os.path.exists(model_path):
        print("Use pre_mode")
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        net.load_state_dict(torch.load(model_path, map_location=device))
        if os.path.exists(optim_path):
            optimizer.load_state_dict(torch.load(optim_path, map_location=device))
    else:
        os.makedirs(mode_dir, exist_ok=True)

    train_losses = []  # per-batch loss over the whole run
    train_acces = []   # per-batch accuracy (%) over the whole run

    for epoch in range(max_epoch):
        net = net.train()  # enable training behaviour (e.g. BatchNorm updates)
        epoch_start = len(train_losses)  # first list index of this epoch
        for step, (img, label) in enumerate(train_data):
            img = img.to(device)
            label = label.to(device)

            optimizer.zero_grad()  # clear gradients from the previous step
            output = net(img)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            # Batch accuracy from the arg-max of the logits.
            _, pred = torch.max(output, 1)
            num_correct = (pred == label).sum().item()
            train_acc = num_correct / img.shape[0]

            train_loss = loss.item()
            train_losses.append(train_loss)
            train_acces.append(train_acc * 100)

            if (step + 1) % display_step == 0 or (step + 1) == len(train_data):
                print('Epoch:{} [{}/{}], Loss:{:.4f}, Acc:{}'.format(epoch, step + 1, len(train_data), train_loss, train_acc))

        # Checkpoint after every epoch.
        torch.save(net.state_dict(), model_path)
        torch.save(optimizer.state_dict(), optim_path)

        # Average over this epoch's batches only; the lists span all epochs.
        print('Epoch {}: Train Loss: {} Train  Accuracy: {} '.format(
            epoch + 1, np.mean(train_losses[epoch_start:]), np.mean(train_acces[epoch_start:])))
        # Validate on the test set after every epoch.
        print("############## val this epoch ##############")
        test()

    draw_train_process('training', range(len(train_losses)), train_losses, train_acces, 'training loss', 'training acc')

def test():
    """Evaluate the global ``net`` on ``test_data``.

    Prints the sample-weighted average loss and accuracy (fraction) over the
    whole test set, plots the per-batch loss / accuracy (%) curves, then shows
    predictions for a few test images.
    """
    net.eval()
    test_loss = []  # per-batch mean loss, for the chart
    test_acc = []   # per-batch accuracy in percent, matching train()'s chart
    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    with torch.no_grad():  # no gradients are needed during evaluation
        for img, label in test_data:
            # The model lives on `device`; move each batch there as well.
            img = img.to(device)
            label = label.to(device)
            output = net(img)

            n = img.shape[0]
            batch_loss = criterion(output, label).item()
            num_correct = (output.argmax(dim=1) == label).sum().item()

            test_loss.append(batch_loss)
            test_acc.append(100.0 * num_correct / n)
            # criterion averages over the batch (default 'mean' reduction), so
            # re-weight by batch size: the last batch is usually smaller.
            total_loss += batch_loss * n
            total_correct += num_correct
            total_seen += n

    avg_loss = total_loss / total_seen if total_seen else float('nan')
    avg_acc = total_correct / total_seen if total_seen else float('nan')
    print("avg loss:{}, avg acc:{}".format(avg_loss, avg_acc))

    # Draw loss and acc charts
    draw_train_process("Test", range(len(test_loss)), test_loss, test_acc, "testing loss", "testing acc")
    showTestResult(test_data)

if __name__=="__main__":
    # Script entry point: run training, which also validates after every epoch.
    train()

Epoch:0 [100/938], Loss:0.0004, Acc:1.0
Epoch:0 [200/938], Loss:0.0572, Acc:0.984375
Epoch:0 [300/938], Loss:0.0010, Acc:1.0
Epoch:0 [400/938], Loss:0.0344, Acc:0.984375
Epoch:0 [500/938], Loss:0.0047, Acc:1.0
Epoch:0 [600/938], Loss:0.0010, Acc:1.0
Epoch:0 [700/938], Loss:0.0003, Acc:1.0
Epoch:0 [800/938], Loss:0.0005, Acc:1.0
Epoch:0 [900/938], Loss:0.0439, Acc:0.984375
Epoch:0 [938/938], Loss:0.0022, Acc:1.0
Epoch 1: Train Loss: 0.011573510118567332 Train Accuracy: 0.9966184701492538
############## val this epoch ##############
avg loss:0.030577583238482475, avg acc:0.9914410828025477
Epoch:1 [100/938], Loss:0.0000, Acc:1.0
Epoch:1 [200/938], Loss:0.0001, Acc:1.0
Epoch:1 [300/938], Loss:0.0247, Acc:0.984375
Epoch:1 [400/938], Loss:0.0002, Acc:1.0
Epoch:1 [500/938], Loss:0.0114, Acc:1.0
Epoch:1 [600/938], Loss:0.1161, Acc:0.984375
Epoch:1 [700/938], Loss:0.0000, Acc:1.0
Epoch:1 [800/938], Loss:0.0003, Acc:1.0
Epoch:1 [900/938], Loss:0.0000, Acc:1.0
Epoch:1 [938/938], Loss:0.0003, Acc:1.0
Epoch 2: Train Loss: 0.011195699965508517 Train Accuracy: 0.9966767723880597
############## val this epoch ##############
avg loss:0.03695254400372505, avg acc:0.9913415605095541
Epoch:2 [100/938], Loss:0.1651, Acc:0.984375
Epoch:2 [200/938], Loss:0.0689, Acc:0.984375
Epoch:2 [300/938], Loss:0.0001, Acc:1.0
Epoch:2 [400/938], Loss:0.0031, Acc:1.0
Epoch:2 [500/938], Loss:0.0114, Acc:1.0
Epoch:2 [600/938], Loss:0.3162, Acc:0.9375
Epoch:2 [700/938], Loss:0.0206, Acc:0.984375
Epoch:2 [800/938], Loss:0.0003, Acc:1.0
Epoch:2 [900/938], Loss:0.0047, Acc:1.0
Epoch:2 [938/938], Loss:0.0562, Acc:0.96875
Epoch 3: Train Loss: 0.011290731870277657 Train Accuracy: 0.9967073116560057
############## val this epoch ##############
avg loss:0.027084697037935257, avg acc:0.9937300955414012

可视化中间结果,每一个epoch结果如下:
在这里插入图片描述在这里插入图片描述
在这里插入图片描述在这里插入图片描述

在这里插入图片描述在这里插入图片描述

训练过程可视化:
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值