【新手向】【Kaggle】基于CNN的简易MNIST数字识别器PyTorch实现(99.5%准确率)

【新手向】【Kaggle】基于CNN的简易MNIST数字识别器PyTorch实现(99.5%准确率)

MNIST中的数据


原文链接: ​[99.5%]Introduction to CNN w/ PyTorch | Kaggle之所以打算在CSDN上水一遍是因为自己实在没活儿,整个这个水一行活动经历
原竞赛: Digit Recognizer | Kaggle

前言

这是我第一次在平台上写文(文笔只能以悲剧来形容),码龄约等于0,也是刚入门ML,还是有点紧张捏。如果你发现了任何错误不严谨或有任何意见,别犹豫,leave your comments below!
本文要求读者对深度学习和PyTorch有最基本的了解和使用经验,旨在为读者朋友们提供一个数字识别器的简单又高效的PyTorch实现。本文也会简单的介绍一些重要的提高深度学习性能的入门技巧供大家学习。虽然本人也是萌新,但如果有问题的话火速提问,我会尽力回答,读者随后可再多方询问研究:讨论能给参与者们一个很好的学习机会。

1. MNIST数字识别器简介

1.1. 文章简介

与计算机视觉(CV)相关的研究已经有60年的历史,但由于已标签的数据量和计算能力的爆炸性发展,它最近才在工业界获得大量关注。图像分类是计算机视觉中最基本的任务之一。目前的很多应用场景下我们通常会使用深度学习模型来分析图像(在拥有大量高质量复杂数据的前提下,深度学习方案通常可以替代人工特征工程,更高效的部署模型同时拥有较高的准确度)。图像分类的先驱模型之一是由Yann LeCun等人在1989年提出的LeNet,一个用于数字识别的卷积神经网络(CNN)。在本教程中,我们将尝试实现一个类似于LeNet的CNN,以实现同样的任务–数字识别。本文将围绕 Digit Recognizer | Kaggle新手练习赛进行编程。

1.2. MNIST 数据集

MNIST数据库是一个大型的手写数字数据库(由0到9之间手绘数字的黑白图像组成)。我们将使用Kaggle提供的训练数据集(MNIST中的一部分数据)来训练我们的CNN模型。
让我们首先了解Kaggle提供的MNIST数据集的结构。运行下面的代码将显示我们训练集的前5行。

import pandas as pd
trainset = pd.read_csv('../input/digit-recognizer/train.csv')
trainset.head()

输出:
trainset.head()
注:代码中使用的路径为Kaggle比赛文件路径。

1.3. 数据集使用规则

在这次比赛中使用MNIST原始数据集进行训练会被认作作弊,因为MNIST原数据集包含测试集。我们只能使用Kaggle提供的训练集进行训练。有一点需要注意的是,使用测试集来评估模型的性能并没有被禁止。我们可以基于原数据集创建一个标注的测试数据集来评估模型的性能。更多细节见5.2。在测试集上检查准确性 (请不要在Kaggle对应竞赛上提交已标注的测试集)

2. 数据集的准备和预处理

2.1. 导入包

import numpy as np # 数据处理
import pandas as pd # 数据处理, CSV文件 I/O (例如 pd.read_csv)
import torch # PyTorch
from torch.utils.data import Dataset, DataLoader # 数据载入
from torch.utils.data.sampler import SubsetRandomSampler # 数据预处理
from torchvision import transforms # 图像数据预处理
import torch.nn as nn # 神经网络
import torch.nn.functional as F # 函数
import torch.optim as opt # 优化器

2.2. 导入数据

首先,让我们读入训练集和测试集的csv文件。数据可在此处下载。

train_set = pd.read_csv('../input/digit-recognizer/train.csv')
test_set = pd.read_csv('../input/digit-recognizer/test.csv')

打乱数据并选择出用于训练集与验证集。

VALID_SIZE = 0.1 # 用于验证集的数据比例

num_train = len(train_set)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(VALID_SIZE * num_train))
train_indices, valid_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

print(f'training set length: {
     len(train_indices)}')  # 37800
print(f'validation set length: {
     len(valid_indices)}')  # 4200

完成随机采样后,定义一个名为DatasetMNIST的Torch数据集的子类,这样Torch的Dataloader就可以正确使用我们的MNIST数据集。

class DatasetMNIST(Dataset):
    def __init__(self, data, transform=None, labeled=True):
        self.data = data
        self.transform = transform
        self.labeled = labeled

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):  # 重写,定义数据载入行为
        item = self.data.iloc[index]
        if self.labeled:  # 处理已标注数据
            x = item[1:].values.astype(np.uint8).
  • 5
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值