MNIST图像分类任务功能性代码补充
数据集说明
接上一篇博客(基于pytorch的深度学习的MNIST手写数字图像分类)的内容。其中提到了MNIST数据集,在那篇博文中,是使用以下代码调用torchvision库中的MNIST数据集并完成训练的。
# MNIST数据集
# 训练集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# 验证集
val_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
# 测试集
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
运行此代码便会自动将MNIST数据集下载到./data
目录下
/MNIST
---/processed
------/training.pt
------/test.pt
---/raw
------/t10k-images-idx3-ubyte.gz
------/t10k-labels-idx1-ubyte.gz
------/train-images-idx3-ubyte.gz
------/train-labels-idx1-ubyte.gz
其中包含了2个权重文件和4个.gz
格式的二进制数据压缩文件,images数据中每个图像以28x28的大小存储,labels数据中每个标签是一个0到9的整数。可以使用 torchvision.utils.save_image(image, image_path)
指令将数据集解析为图像+标签的格式。
模型的保存与读取
模型的保存有2种方式,一种是保存模型的状态字典(即模型的权重),另一种是保存整个模型(包括模型架构和权重)。可以使用以下两行代码实现模型的保存与读取:
# 创建一个新的模型实例
model = NeuralNetwork()
# 保存模型
torch.save(model.state_dict(), '文件名')
torch.save(model, '文件名')
# 加载状态字典
model.load_state_dict(torch.load('文件名'))
# 加载整个模型
model = torch.load('文件名')
loss和accuracy图像绘制
在训练结束后绘制图像
图像可以更加形象地展示训练的过程,在训练过程中时刻记录loss和accuracy的值,以便在训练结束后绘制图像。
def plot_loss_accuracy(train_losses, accuracies):
clear_output(wait=True)
# 绘制损失曲线
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses)+1), train_losses, label='Train Loss', color='blue')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Train Loss')
plt.legend()
plt.grid(True) # 启用网格线
# 绘制准确率曲线
plt.subplot(1, 2, 2)
plt.plot(range(1, len(accuracies)+1), accuracies, label='Test Accuracy', color='red')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.title('Test Accuracy')
plt.legend()
plt.grid(True) # 启用网格线
plt.tight_layout()
# 保存图像
plt.savefig(f'plot.jpg', dpi=600) # 设置dpi保证清晰度
plt.show() # 显示图像
最终训练完成后保存的图像如下:
在训练过程中实时绘制图像
如果想要在训练过程中实时绘制图像,只需要在迭代循环中加入以下代码即可:
# 在循环开始前创建画布
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(tqdm(train_loader)):
..........
# 训练代码
..........
# 实时更新绘制曲线
axs[0].clear()
axs[0].plot(train_losses, label='Train Loss')
axs[0].set_xlabel('Epoch')
axs[0].set_ylabel('Loss')
axs[0].legend()
axs[1].clear()
axs[1].plot(accuracies, label='Test Accuracy', color='green')
axs[1].set_xlabel('Epoch')
axs[1].set_ylabel('Accuracy (%)')
axs[1].legend()
plt.tight_layout()
plt.draw()
plt.pause(0.001)
这样就可以在训练的过程中产生一个交互式实时显示页面:
模型数据统计
一般统计模型信息包含模型的参数量和模型的FLOPs
。模型的参数量可以反映模型的复杂度,较多的参数可能意味着模型可以拟合更复杂的函数,但也可能导致过拟合问题,进而反应模型的性能。FLOPs
(浮点运算数)是衡量模型在推理阶段所需计算资源的重要指标,它反映了模型进行前向传播时所执行的浮点运算数量,通过分析模型的FLOPs
,可以识别模型中的计算密集型部分。
在实例化模型后加入以下代码即可得到以上统计信息:
model = NeuralNetwork()
total_params = sum(p.numel() for p in net.parameters())
print("模型参数: ", total_params)
flops_count = 2 * sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"FLOPs: {flops_count:,}")
根据"基于pytorch的深度学习的MNIST手写数字图像分类"中的案例,运行上述代码后,其模型的相关统计信息为:
模型参数: 235146
FLOPs: 470,292
可视化测试结果
在训练完成后,可以使用测试数据对训练的模型进行测试,除了获得使用评价指标进行评估的统计数据之外,还可以输入一些图像使用模型测试获得可视化示例,这样可以更加直观地展示模型的测试效果。
首先载入模型:
model = NeuralNetwork() # 实例化网络
model.load_state_dict(torch.load('checkpoints/final_dict_model.pth')) # 载入模型
model.eval() # 开启评估模式
载入数据,并保持数据预处理的形式与训练时一致:
# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 载入测试集
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
选择待测试和可视化的图像:
# 从测试集中随机选择20张图像
selected_samples = random.sample(range(len(test_dataset)), 20)
评估测试图像:
for i, idx in enumerate(selected_samples):
image, label = test_dataset[idx]
image = image.reshape(-1, 28*28)
with torch.no_grad():
output = model(image)
_, predicted = torch.max(output.data, 1)
print("测试第{}张图像完毕!".format(i+1))
可视化结果:
for i, idx in enumerate(selected_samples):
......
# 测试评估代码
......
# 将单张图像添加到大图中
row = i // 5
col = i % 5
axs[row, col].imshow(image[0].numpy().reshape(28, 28), cmap='gray')
axs[row, col].set_title(f'Predicted: {predicted.item()}, True Label: {label}', fontsize=9)
axs[row, col].axis('off')
plt.show()
最终的结果为:
关于我
欢迎与我交流!
公众号:AI小火车