手写字体识别(2) 数据加载及网络搭建


github地址:
https://github.com/Huyf9/mnist_pytorch/

数据加载

数据加载部分在torch中有现成的Dataset类以及Dataloader类。我们利用Dataset类来创建自己的Dataset。
Dataset类需要重写其中的三个函数:

__ init() # 创建初始化参数
__ getitem() # 根据索引返回图片信息与标签信息
__ len() # 返回数据的长度,方便Dataloader对数据进行批操作

import torch
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class Mydataset(Dataset):
    def __init__(self, pic_path, label_path):
        super(Mydataset, self).__init__()
        pics = open(pic_path, 'r', encoding='utf-8').readlines()
        labels = open(label_path, 'r', encoding='utf-8').readlines()
        self.pics = pics
        self.labels = labels

    # 索引你的数据和标签并返回
    def __getitem__(self, index):
        pic = np.array(Image.open(self.pics[index].strip()))
        label = int(self.labels[index].strip())

        # 转成tensor格式
        pic = torch.tensor(pic)
        pic = torch.unsqueeze(pic, 0).float()
        label = torch.tensor(label)
        return pic, label

    # 返回数据长度
    def __len__(self):
        return len(self.pics)

由于在卷积操作中,支持的图片维度为[B: batch, C: channel, H: height, W: width],因此在__getitem()函数中我们对图片利用torch.unsqueeze()函数进行升维操作。

模型建立

我们建立一个简单的多层感知机模型与简单的卷积模型来分别进行训练。

多层感知机模型

多层感知机也叫人工神经网络(ANN),除了输入输出层,它中间可以有多个隐层,最简单的MLP只含一个隐层,即三层的结构。我们使用最简单的三层感知机。

import torch.nn as nn
import torch

class MlpNet(nn.Module):
    def __init__(self):
        super(MlpNet, self).__init__()
        # 28 x 28
        self.fc1 = nn.Linear(784, 392)
        self.sigmoid1 = nn.Sigmoid()
        self.fc2 = nn.Linear(392, 10)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # x = torch.flatten(torch.squeeze(x, 0))
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.sigmoid1(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x


if __name__ == '__main__':
    data = torch.rand([1, 28, 28])
    data = torch.unsqueeze(data, 0)
    print(data.shape)
    net = MlpNet()
    print(net)
    print(net(data).shape)

卷积模型

多层感知机虽然可以比较好的识别手写字体,但遇到旋转的图片时精度下降,此时可以使用卷积神经网络来提升精度。卷积网络由于权值共享使得其具有平移不变性,由于maxpooling操作的存在,又使其具有旋转不变性

import torch.nn as nn
import torch

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # 28 x 28
        self.Conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.pool1 = nn.MaxPool2d(2, 1)
        self.relu1 = nn.ReLU()

        self.Conv2 = nn.Conv2d(in_channels=10, out_channels=10, kernel_size=5)
        self.pool2 = nn.MaxPool2d(2, 1)
        self.relu2 = nn.ReLU()

        self.Conv3 = nn.Conv2d(in_channels=10, out_channels=10, kernel_size=5)
        self.pool3 = nn.MaxPool2d(2, 1)
        self.relu3 = nn.ReLU()

        self.drop = nn.Dropout(0.8)  # 将80%的神经元失活
        self.fc = nn.Linear(10*13*13, 10)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.Conv1(x)
        x = self.pool1(x)
        x = self.relu1(x)

        x = self.Conv2(x)
        x = self.pool2(x)
        x = self.relu2(x)

        x = self.Conv3(x)
        x = self.pool3(x)
        x = self.relu3(x)

        x = x.view(-1, 10*13*13)  # 将四维矩阵纬度拉成一维  [Batch, Channel, H, W]
        x = self.drop(x)
        x = self.fc(x)
        x = self.softmax(x)

        return x


if __name__ == '__main__':
    data = torch.rand([1, 28, 28])
    data = torch.unsqueeze(data, 0)
    print(data.shape)
    net = ConvNet()
    print(net)
    print(net(data).shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值