手写数字识别

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

batch_size = 4
transform = transforms.Compose([
    transforms.ToTensor(),  # 把PIL图像转化为Tensor,W*H*C->C*W*H
    transforms.Normalize((0.1307,), (0.3081,))  # 第一个是均值,第二个是标准差,针对MNIST
])

train_dataset = datasets.MNIST(root='../dataset/',
                               train=True,
                               transform=transform,
                               download=True)
test_dataset = datasets.MNIST(root='../dataset/',
                              train=False,
                              transform=transform,
                              download=True)
train_loader = DataLoader(dataset=train_dataset,           # 可迭代的训练集 每迭代一次 取一个batch_size大小的数据
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=10, kernel_size=1)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=3)
        self.conv3 = torch.nn.Conv2d(20, 30, kernel_size=3)
        self.pooling = torch.nn.MaxPool2d(kernel_size=2)
        self.fc = torch.nn.Linear(in_features=120, out_features=10, bias=True)
        # in_features:上层网络神经元的个数
        # out_features:该网络层神经元的个数
        # bias:网络层是否有偏置,默认为True,且维度为[out_features]

    def forward(self, x):      # 前向传播
        batch_size = x.size(0)  # 输入(batch,1,28,28)
        x = F.relu(self.pooling(self.conv1(x)))  # 经第一次卷积池化 (batch,10,14,14)
        x = F.relu(self.pooling(self.conv2(x)))  # (batch,20,6,6)
        x = F.relu(self.pooling(self.conv3(x)))  # (batch,30,2,2)    通道数 30, 图像大小 2x2
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x


model = Net()

criterion = torch.nn.CrossEntropyLoss()        # 损失函数使用交叉熵
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)  # 优化器实例


def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):   # 遍历训练集
        inputs, target = data
        optimizer.zero_grad()   # 将梯度归零

        # forward+backward+update
        outputs = model(inputs)  # 将数据传入网络进行前向运算
        loss = criterion(outputs, target)   # 得到损失函数
        loss.backward()              # 反向传播计算梯度下降
        optimizer.step()  # 通过梯度做一步参数更新

        running_loss += loss.item()     
        if batch_idx % 300 == 299: 
            print('[%d,%5d]loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

global predicted


def test():
    correct = 0
    total = 0
    with torch.no_grad():  # 不计算梯度
        for data in test_loader:  # test_loader中按batch把数据分开
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            # print('labels=', labels)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set:%d %% [%d/%d]' % (100 * correct / total, correct, total))


if __name__ == '__main__':
    for epoch in range(1):
        train(epoch)
        test()
    images, labels = next(iter(test_loader))
    image = torchvision.utils.make_grid(images)
    image = image.numpy().transpose(1, 2, 0)
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    image = image * std + mean
    outputs = model(images)
    _, predicted = torch.max(outputs.data, dim=1)
    print("预测值:", predicted)
    print("标签值:", labels)
    fig = plt.figure(figsize=(10, 5))
    plt.title('手写数字识别', fontproperties='stsong')
    fig, axes = plt.subplots(2, 2)
    for i in range(4):
        ax = plt.subplot(2, 2,  i + 1)  # 画多个子图(2*2)
        ax.imshow(np.reshape(images[i], (28, 28)), cmap='binary')  # 显示第index张图像
        title = "label=" + str(torch.max(labels[i]))
        title += ",predict=" + str(predicted[i])  # 构建图片上要显示的title
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])  # 不显示坐标轴
        ax.set_yticks([])
        plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文
        plt.suptitle("手写数字识别", x=0.5, y=1, fontsize=16)
    plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

光着膀子的宁

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值