import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.autograd import Variable
from model import Net # 自定义神经网络模型
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# ---- Network, loss function and optimizer ---------------------------------
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=0.01)

# ---- MNIST training data ---------------------------------------------------
# ToTensor() converts each PIL image into a float tensor scaled to [0, 1];
# no further normalisation is applied (inference uses the same transform).
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = MNIST('./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Training, variant 1 (not used): an earlier version flattened every batch with
# ``images.view(-1, 28 * 28)`` and wrapped tensors in ``torch.autograd.Variable``.
# ``Variable`` is deprecated (plain tensors carry autograd state), and the
# flattening only makes sense if ``Net`` expects flat 784-element inputs.
# NOTE(review): confirm the expected input shape against model.Net.

# Training, variant 2 (active): 5 epochs of mini-batch SGD, feeding the batches
# to the network exactly as the loader yields them.
model.train()  # explicit training mode (matters if Net has dropout/batch-norm)
for epoch in range(5):
    for inputs, labels in train_loader:
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(inputs)            # forward pass
        loss = criterion(outputs, labels)  # cross-entropy against the true digits
        loss.backward()                    # back-propagate
        optimizer.step()                   # apply the SGD update
# Save the trained weights (only the state dict, not the module object).
torch.save(model.state_dict(), 'model.pth')
# Reload: build a fresh network and restore the saved weights into it.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer torch versions support weights_only=True for state dicts).
model = Net()
model.load_state_dict(torch.load('model.pth'))
# Switch to inference mode (affects dropout/batch-norm layers, if Net has any).
model.eval()
# Preprocess a user-supplied image into a batch the model can consume.
def load_image(image_path, *, invert=False):
    """Load an image file and return it as a single-image batch tensor.

    The image is converted to 8-bit grayscale and resized to 28x28. Note that
    the resize does not preserve aspect ratio. It is then turned into a float
    tensor in [0, 1] with ``ToTensor()``, the same transform used for the
    training data.

    Args:
        image_path: Path to the image file.
        invert: If True, return ``1 - x`` so that a dark digit drawn on a light
            background matches the light-on-dark style of MNIST digits.
            Defaults to False, which is the original behaviour.

    Returns:
        A float tensor of shape (1, 1, 28, 28).
    """
    # Use a context manager so the underlying file handle is released.
    # convert() returns a new, fully loaded image that outlives the block.
    with Image.open(image_path) as img:
        gray = img.convert('L')  # convert to grayscale
    transformation = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    tensor = transformation(gray)  # (1, 28, 28): one grayscale channel
    if invert:
        tensor = 1.0 - tensor
    return tensor.unsqueeze(0)  # add the batch dimension -> (1, 1, 28, 28)
# Digit recognition. Displaying the sample image is optional: pass show=False
# for the plain "predict only" behaviour.
def predict_digit(image_path, show=True):
    """Classify the handwritten digit in ``image_path`` with the global ``model``.

    Args:
        image_path: Path to the image file; it is preprocessed by ``load_image``.
        show: If True (the default), display the preprocessed 28x28 image,
            titled with the prediction, before returning. ``plt.show()`` may
            block until the window is closed. If False, only predict.

    Returns:
        The predicted class index as a Python int.
    """
    image = load_image(image_path)  # (1, 1, 28, 28) batch
    with torch.no_grad():
        output = model(image)
    # Index of the highest score along the class dimension.
    _, predicted = torch.max(output, 1)
    digit = predicted.item()

    if show:
        # Take the single 28x28 plane so imshow receives a plain 2-D array,
        # which it renders with the gray colormap.
        image_numpy = (image[0, 0].numpy() * 255).astype(np.uint8)
        plt.imshow(image_numpy, cmap='gray')
        plt.title(f'Predicted Digit: {digit}')
        plt.axis('off')  # hide the axes
        plt.show()
    return digit
# Recognise the digit in a user-supplied image and print the result.
image_path = '数字5.png'  # path of the uploaded image
predicted_digit = predict_digit(image_path)
print('本次识别的结果为:', predicted_digit)