# 记录:超简单、超基本的数字识别

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.autograd import Variable
from model import Net  # 自定义神经网络模型
from PIL import Image

import matplotlib.pyplot as plt
import numpy as np
# Training data: MNIST images converted to tensors and served in shuffled
# batches of 64.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Network, loss function and optimizer used by the training loop.
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
)

'''
# 训练模型方法一
for epoch in range(5):
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images.view(-1, 28*28))
        labels = Variable(labels)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
'''
# Training, method two: feed every batch to the model exactly as the loader
# yields it.
# NOTE(review): the disabled "method one" string above reshapes each batch to
# (-1, 28*28) before the forward pass, while this loop does not -- confirm
# which input shape Net expects.
for epoch in range(5):  # five passes over the training set
    for batch_inputs, batch_labels in train_loader:
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)

        # Clear the gradients left over from the previous step, then
        # backpropagate this batch's loss and update the parameters.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

# Persist the trained parameters (the state dict) to disk.
torch.save(model.state_dict(), 'model.pth')

# Read the parameters back into a freshly constructed network and put it in
# evaluation mode for the predictions below.
saved_state = torch.load('model.pth')
model = Net()
model.load_state_dict(saved_state)
model.eval()

# Preprocess a user-supplied image into the tensor that is fed to the model.
def load_image(image_path) -> torch.Tensor:
    """Load an image file as a single-image grayscale batch tensor.

    Args:
        image_path: Location of the image file; passed to ``Image.open``.

    Returns:
        The image converted to grayscale, resized to 28x28, turned into a
        tensor by ``ToTensor`` and given a leading batch dimension, i.e. a
        tensor of shape ``(1, 1, 28, 28)``.
    """
    # Open inside a ``with`` block so the underlying file handle is released
    # as soon as the pixel data has been converted, instead of being left to
    # lazy loading / garbage collection.
    with Image.open(image_path) as source:
        grayscale = source.convert('L')  # single-channel grayscale copy

    # NOTE(review): like the training transform, this applies ToTensor only
    # (no normalisation or inversion) -- confirm that uploaded images use the
    # same foreground/background polarity as the training digits.
    transformation = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    return transformation(grayscale).unsqueeze(0)  # prepend the batch axis

'''
============================
#进行数字识别 (不显示样本图像版本代码)

def predict_digit(image_path):
    image = load_image(image_path)
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output.data, 1)
    return predicted.item()
# 使用上传图片进行数字识别
image_path = '数字.png'  # 上传的图片路径
predicted_digit = predict_digit(image_path)
print('识别结果为:', predicted_digit)
==============================
'''

# ======== Digit recognition with an optional image preview ========
# This replaces the disabled no-display variant kept in the string above;
# pass show=False to get that variant's behaviour without swapping code.


def _show_prediction(image, digit):
    """Display the preprocessed input image, titled with the predicted digit."""
    # Drop the batch axis and move channels last: (C, H, W) -> (H, W, C).
    pixels = image.squeeze(0).numpy().transpose((1, 2, 0))
    if pixels.shape[-1] == 1:
        # Collapse the single channel so imshow receives a plain 2-D array;
        # an (H, W, 1) array is not accepted by every matplotlib version.
        pixels = pixels[..., 0]
    pixels = (pixels * 255).astype(np.uint8)  # rescale to 8-bit gray levels

    plt.imshow(pixels, cmap='gray')
    plt.title(f'Predicted Digit: {digit}')
    plt.axis('off')  # hide the axes
    plt.show()


def predict_digit(image_path, *, show=True):
    """Predict the digit shown in an image file.

    Args:
        image_path: Path of the image file; forwarded to ``load_image``.
        show: When true (the default, matching the previous behaviour), also
            display the preprocessed image in a matplotlib figure titled
            with the prediction. Pass ``False`` to only compute the result.

    Returns:
        The index of the largest model output for the image, as a Python int.
    """
    image = load_image(image_path)
    with torch.no_grad():  # inference only, no gradients are needed
        output = model(image)
    digit = output.argmax(dim=1).item()

    if show:
        _show_prediction(image, digit)

    return digit
# ======== End of the digit-recognition code with optional preview ========


# Run digit recognition on the uploaded image and print the result.
image_path = '数字5.png'  # path of the uploaded image
predicted_digit = predict_digit(image_path)
print('本次识别的结果为:', predicted_digit)

 运行结果

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值