pytorch实现手写数字识别

coding=utf-8

“”"
author:lei
function: 使用pytorch完成手写数字的识别
“”"

import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torch
import os

BATCH_SIZE = 512
TEST_BATCH_SIZE = 512

1、准备数据集

def get_dataloader(train=True, batch_size=BATCH_SIZE):
transform_fn = Compose([
ToTensor(),
Normalize(mean=(0.1307,), std=(0.3081,)) # mean和std的形状和通道数相同
])
dataset = MNIST(root="./data", train=True, transform=transform_fn)
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, )

return data_loader

2、构建模型

class MnistModel(nn.Module):
def init(self):
super(MnistModel, self).init()
self.fc1 = nn.Linear(12828, 28)
self.fc2 = nn.Linear(28, 10)

def forward(self, input):
    """
    :param input: [batch_size, 1, 28, 28]
    :return:
    """
    # 1、修改形状  最好别直接写batch_size,如果batch_size进行改变,这里会出错;可用input.size[0]或者 -1 进行代替
    x = input.view([input.size()[0], 1*28*28])

    # 2、进行全连接的操作
    x = self.fc1(x)

    # 3、进行激活函数的处理,形状没有变化
    x = F.relu(x)

    # 4、输出层
    out = self.fc2(x)

    return F.log_softmax(out)

model = MnistModel()
optimizer = Adam(model.parameters(), lr=0.001)

判断模型是否存在,模型的加载

if os.path.exists("./model/model.pkl"):
model.load_state_dict(torch.load("./model/model.pkl"))
optimizer.load_state_dict(torch.load("./model/optimizer.pkl"))

def train(epoch):
“”"
实现训练的过程
:param epoch:
:return:
“”"
data_loader = get_dataloader()

for idx, (input, target) in enumerate(data_loader):
    output = model(input)  # 调用模型,得到预测值
    loss = F.nll_loss(output, target)  # 得到损失
    optimizer.zero_grad()  # 梯度置为0
    loss.backward()  # 反向传播
    optimizer.step()  # 梯度进行更新

    if idx % 10 == 0:
        print(epoch, idx, loss.item())

    # 模型的保存
    if idx % 100 == 0:
        # 保存模型参数、保存模型优化器
        torch.save(model.state_dict(), "./model/model.pkl")
        torch.save(optimizer.state_dict(), "./model/optimizer.pkl")

def test():
test_dataloader = get_dataloader(train=False, batch_size=TEST_BATCH_SIZE)

loss_list = []
acc_list = []

for idx, (input, target) in enumerate(test_dataloader):
    with torch.no_grad():
        output = model(input)  # 输出[batch_size, 10]
        cur_loss = F.nll_loss(output, target)
        # output [batch_size, 10] target:[batch_size]
        pred = output.max(dim=-1)[-1]  # 获取到最大值的位置
        cur_acc = pred.eq(target).float().mean()  # 将预测值和真实值进行比较,返回的结果为bool类型,将bool类型转换为浮点类型并求均值
        acc_list.append(cur_acc)  # 将每个batch放到acc_list中
        loss_list.append(loss_list)

print("平均准确率,平均损失:", np.mean(acc_list), np.mean(loss_list))  # 求均值

if name == ‘main’:
# for i in range(3): # 训练三轮
# train(i)

# 进行测试
test()

# loader = get_dataloader(train=False)
#
# for input, label in loader:
#     print(label.size())  # torch.Size([128])
#     break
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值