基于PyTorch的MNIST手写数字识别(配置手写板使用)
代码详见:https://github.com/xiaozhou-alt/CNN_MNIST
文章目录
一、项目介绍
本项目实现了一个基于PyTorch的MNIST手写数字识别系统,包含完整的训练、评估和可视化界面功能。系统使用卷积神经网络(CNN)模型,能够识别0-9的手写数字,并提供了友好的GUI界面进行实时手写识别。
二、数据集介绍
数据集获取:代码文件中包含了数据集以及下载的代码实现方式
使用标准MNIST数据集:
- 训练集:60,000张28x28灰度手写数字图像
- 测试集:10,000张28x28灰度手写数字图像
- 数字范围:0-9
三、项目实现
1.环境准备
- Python 3.7+
必须库:
pip install torch torchvision numpy pandas matplotlib seaborn PyQt5
2.项目文件夹结构
CNN_MNIST/
├── MINIST-master/data/ # MNIST数据集
│ ├── train-images-idx3-ubyte.gz # (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
│ ├── train-labels-idx1-ubyte.gz # (29 KB, 解压后 60 KB, 包含 60,000 个标签)
│ ├── t10k-images-idx3-ubyte.gz # (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
│ └── t10k-labels-idx1-ubyte.gz # (5KB, 解压后 10 KB, 包含 10,000 个标签)
├── utils/
│ └── data_loader.py # 数据加载和预处理
├── train.py # 模型训练脚本
├── evaluate.py # 模型评估脚本
├── handwriting_ui.py # 手写识别GUI界面
└── final_model.pth # 训练好的模型权重
3.数据预处理
因为直接使用数据集进行训练的手写板效果一坨
为了提高数据的多样性和模型的鲁棒性,此处使用三种方法对数据进行增强处理:
(1) 随机旋转 (Random Rotation)
- 旋转角度范围:±15度
- 使用transforms.functional.rotate实现
- 随机选择旋转角度
def random_rotate(self, img, degree=15):
angle = random.uniform(-degree, degree)
return transforms.functional.rotate(img, angle)
(2) 随机缩放 (Random Scale)
- 缩放比例范围:0.9-1.1倍
- 保持图像宽高比不变
- 使用transforms.functional.resize实现
def random_scale(self, img, scale_range=(0.9, 1.1)):
scale = random.uniform(*scale_range)
h, w = img.shape[-2:]
new_h, new_w = int(h * scale), int(w * scale)
return transforms.functional.resize(img, (new_h, new_w))
(3) 随机平移 (Random Shift)
- 最大平移距离:±2像素
- 水平和垂直方向独立随机平移
- 使用transforms.functional.affine实现
def random_shift(self, img, max_shift=2):
h_shift = random.randint(-max_shift, max_shift)
v_shift = random.randint(-max_shift, max_shift)
return transforms.functional.affine(img, angle=0, translate=(h_shift, v_shift), scale=1.0, shear=0)
4.开始训练!
(1) 数据加载
- 使用load_mnist_images和load_mnist_labels函数从 gzip 压缩文件中读取 MNIST 数据集
- 图像数据被读取为 uint8 格式并 reshape 为28x28的二维数组
- 标签数据直接从文件中读取
(2) 数据转换
- 使用transforms.Compose组合多个转换操作
- ToTensor()将图像转换为 PyTorch 张量
- Normalize()对数据进行标准化处理
def load_mnist_images(filename):
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
return data.reshape(-1, 28, 28)
def load_mnist_labels(filename):
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=8)
return data
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
(3) 模型定义
- 使用两层卷积和两层全连接
- 包含 ReLU 激活和 MaxPooling 操作
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 64 * 7 * 7)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
(4) 训练过程
- 使用 Adam 优化器和交叉熵损失函数
- 每100个 batch 打印一次训练进度
- 每轮训练后保存模型检查点,最后保存最优模型
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')
# ... existing code ...
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
# ... existing code ...
}, f'./model_epoch_{epoch}.pth')
torch.save(model.state_dict(), './final_model.pth')
(5) 评估测试
计算在测试集上的准确率(Accuracy)、召回率(Recall)、F1分数(f1-score)对模型进行评估
def evaluate():
correct = 0
total = 0
predictions = []
# 初始化混淆矩阵
confusion_matrix = torch.zeros(10, 10, dtype=torch.int64)
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
predictions.extend(predicted.numpy())
# 构建混淆矩阵
for t, p in zip(labels.view(-1), predicted.view(-1)):
confusion_matrix[t.long(), p.long()] += 1
# 计算各项指标
accuracy = 100 * correct / total
precision = confusion_matrix.diag() / confusion_matrix.sum(0).float()
recall = confusion_matrix.diag() / confusion_matrix.sum(1).float()
f1 = 2 * precision * recall / (precision + recall)
输出结果如下所示:
四、结果展示
训练损失如下所示:
测试集上的混淆矩阵如下所示:
手写板的 UI 界面如下所示:
MNIST-手写数字识别手写板展示
如果你喜欢我的文章,不妨给小周一个免费的点赞和关注吧!