基于卷积神经网络的MNIST上正确率99.5%+baseline构建及调试经验分享(pytorch实现)

本文分享了基于PyTorch构建的卷积神经网络在MNIST手写数字识别任务上的优化过程,包括模型构建、调试和性能提升策略。通过调整模型结构、优化器、batchsize等,最终实现99.5%以上的测试正确率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

内容简介

MNIST手写数字识别任务是入门神经网络的经典任务。构建一般的二层全连接神经网络或者是简单的卷积神经网络均可以轻松达到正确率99%加,本文在此基础之上分享进一步的模型改进逻辑,并给出对应实验结果供读者参考。

Baseline说明

首先给出可以直接运行baseline,要求安装pytorch,visdom等,有无GPU均可,若有GPU显存占用大概为0.7G。(建议支持GPU,笔者运行的时候i7-9750H 直接顶满了,最佳测试集正确率为99.53%)

模块1 引用库、定batchsize、定训练测试集

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import time
import visdom

batch_size = 256
if torch.cuda.is_available():
    use_cuda = True
else:
    use_cuda = False

train_dataset = datasets.MNIST(root='./data/',
                               train=True,
                               transform=transforms.Compose([transforms.RandomRotation(10),
                                                             #transforms.RandomCrop(22,padding=3,pad_if_needed=True),
                                                             transforms.ToTensor()]),
                               download=True)
test_dataset = datasets.MNIST(root='./data/',
                              train=False,
                              
                              transform=transforms.Compose([#transforms.RandomCrop(22,padding=3,pad_if_needed=True),
                                                             transforms.ToTensor()]))                                                        
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0,)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=0,)

模块2 初始化visdom 此处给出train_loss等三个量和epoch变化的关系

vis = visdom.Visdom(port = 8008) #python -m visdom.server -p 8008  建立visdom本地虚拟服务器的代码 cmd下运行
win_curve = vis.line(
    X = np.array( [0] ),
    Y = np.array( [0] ),
    opts = dict(
           xlabel=
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值