机器学习-MNIST

前言

MNIST是pytorch框架自带的一个手写0-9分类数据集,其中训练集5w张,测试集1w张,每张图片是28*28像素的单通道图片,本文将用全连接线性神经网络和卷积神经网络两种方式来实现对MNIST数据集的分类。

image.png

所需要的三方库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

加载数据

MNIST数据集是pytorch自带,可通过pytorch直接下载,在下载的同时,通过transforms.ToTensor(),将数据转成tensor数据。并使用DataLoader将数据分割成多个mini-batch,当然这步可做可不做,本文的线性回归方式并没有使用DataLoader分割数据,卷积神经网络使用了DataLoader分割数据,已做对照。

def get_data():
    training_data = datasets.MNIST('../../data', train=True, download=True, transform=transforms.ToTensor())
    test_data = datasets.MNIST('../../data', train=False, download=True, transform=transforms.ToTensor())
    training_data = DataLoader(training_data, batch_size=64, shuffle=True)
    test_data = DataLoader(test_data, batch_size=64, shuffle=True)
    return training_data, test_data

training_data, test_data = get_data()

线性神经网络

构建线性神经网络模型

nn.Sequential 是 pytorch 库中的一个类,它允许通过按顺序堆叠多个层来创建神经网络模型。它提供了一种方便的方式来定义和组织神经网络的层。

每一张图片都是单通道28 * 28的图片。也就是说一个一维矩阵就可以表示一张图片,这张图片有28 * 28个特征,因此第一个隐藏层的输入为28 * 28,输出则可自定义,只是确保之后的输出和输入要一一对应。因为该数据集有10个分类,因此最后的输出为10。并使用RuLU()作为激活函数。

def create_model():
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            self.fc = nn.Sequential(
                nn.Linear(in_features=28 * 28, out_features=512),
                nn.ReLU(),
                nn.Linear(in_features=512, out_features=256),
                nn.ReLU(),
                nn.Linear(in_features=256, out_features=128),
                nn.ReLU(),
                nn.Linear(in_features=128, out_features=64),
                nn.ReLU(),
                nn.Linear(in_features=64, out_features=32),
                nn.ReLU(),
                nn.Linear(in_features=32, out_features=16),
                nn.ReLU
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值