手写数字识别模型识别自己的图片

旨在说明如何将训练好的模型运用到自己的图片中。全连接层神经网络模型搭建、训练等代码如下:// An highlighted blockvar foo = 'bar';#MNIST数据集是一个手写字体数据集,包含0-9这10个数字,#其中有55000张训练集,10000张测试集,5000张验证集,图片大小是28x28灰度图import torchfrom torch import nn...
摘要由CSDN通过智能技术生成

本文旨在说明如何将训练好的模型运用到自己的图片中。

1、全连接层神经网络模型搭建、训练等代码如下:

// An highlighted block
var foo = 'bar';
#MNIST数据集是一个手写字体数据集,包含0-910个数字,
#其中有55000张训练集,10000张测试集,5000张验证集,图片大小是28x28灰度图
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import net
import numpy as np
import deal_handwritten_numeral as dhn
#定义一些超参数,
batch_size = 64
learning_rate = 1e-2
num_epoches = 20

#先定义数据预处理
def data_tf(x):
    x = np.array(x, dtype='float32') / 255
    x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到
    # x = x.reshape((-1,)) # 拉平
    x = torch.from_numpy(x)
    return x
   
 #下载训练集
train_dataset = datasets.MNIST(root='data',train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='data', train=False, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size= batch_size, shuffle=False)


#导入网络,定义损失函数和优化方法,模型已在net.py里面定义过了
model = net.simple_Batch_active_Net(28*28, 300, 100, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
#注意这里是导入之前训练的模型参数,如果第一次训练请注释掉
model.load_state_dict(torch.load('model_mnist.pth'))

# 开始训练
for n in range(10):
    train_loss = 0
    train_acc = 0
    eval_loss = 0
    eval_acc = 0
    model.train()
    for im, label in train_loader:#每次取出64个数据,在之前就定义好了
        im = Variable(im)
        label = Variable(label)
        # 前向传播
        out = model(im)
        loss = criterion(out, label)
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 记录误差
        train_loss += loss.item()#对张量对象可以用item得到元素值
        # 计算分类的准确率
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()#对张量对象可以用item得到元素值
        acc = num_correct / im.shape[0]#预测正确数/总数,对于这个程序由于小批量设置的是64,所以总数为64
        train_acc += acc#计算总的正确率,以便求平均值
    print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}'
          .format(n, train_loss / len(train_loader), train_acc / len(train_loader)))
    torch.save(model.state_dict(), 'model_mnist.pth')
    #进入测试阶段
    model.eval()
    for im, label in test_loader:
        # print
  • 18
    点赞
  • 126
    收藏
    觉得还不错? 一键收藏
  • 16
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值