MNIST图像分类任务功能性代码补充

数据集说明

接上一篇博客(基于pytorch的深度学习的MNIST手写数字图像分类)的内容。其中提到了MNIST数据集,在那篇博文中,是使用以下代码调用torchvision库中的MNIST数据集并完成训练的。

# MNIST数据集
# 训练集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# 验证集
val_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# 测试集
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

运行此代码便会自动将MNIST数据集下载到./data目录下

/MNIST
---/processed
------/training.pt
------/test.pt
---/raw
------/t10k-images-idx3-ubyte.gz
------/t10k-labels-idx1-ubyte.gz
------/train-images-idx3-ubyte.gz
------/train-labels-idx1-ubyte.gz

其中包含了2个权重文件和4个.gz格式的二进制数据压缩文件,images数据中每个图像以28x28的大小存储,labels数据中每个标签是一个0到9的整数。可以使用 torchvision.utils.save_image(image, image_path)指令将数据集解析为图像+标签的格式。

模型的保存与读取

模型的保存有2种方式,一种是保存模型的状态字典(即模型的权重),另一种是保存整个模型(包括模型架构和权重)。可以使用以下两行代码实现模型的保存与读取:

# 创建一个新的模型实例
model = NeuralNetwork()
# 保存模型
torch.save(model.state_dict(), '文件名')
torch.save(model, '文件名')

# 加载状态字典
model.load_state_dict(torch.load('文件名'))
# 加载整个模型
model = torch.load('文件名')

loss和accuracy图像绘制

在训练结束后绘制图像

图像可以更加形象地展示训练的过程,在训练过程中时刻记录loss和accuracy的值,以便在训练结束后绘制图像。

def plot_loss_accuracy(train_losses, accuracies):
    clear_output(wait=True)

    # 绘制损失曲线
    plt.subplot(1, 2, 1)
    plt.plot(range(1, len(train_losses)+1), train_losses, label='Train Loss', color='blue')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Train Loss')
    plt.legend()
    plt.grid(True)  # 启用网格线

    # 绘制准确率曲线
    plt.subplot(1, 2, 2)
    plt.plot(range(1, len(accuracies)+1), accuracies, label='Test Accuracy', color='red')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.title('Test Accuracy')
    plt.legend()
    plt.grid(True)  # 启用网格线

    plt.tight_layout()

    # 保存图像
    plt.savefig(f'plot.jpg', dpi=600) # 设置dpi保证清晰度
    plt.show() # 显示图像

最终训练完成后保存的图像如下:
请添加图片描述

在训练过程中实时绘制图像

如果想要在训练过程中实时绘制图像,只需要在迭代循环中加入以下代码即可:

# 在循环开始前创建画布
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(tqdm(train_loader)):
    ..........
    # 训练代码
    ..........
	# 实时更新绘制曲线
	axs[0].clear()
	axs[0].plot(train_losses, label='Train Loss')
	axs[0].set_xlabel('Epoch')
	axs[0].set_ylabel('Loss')
	axs[0].legend()
	
	axs[1].clear()
	axs[1].plot(accuracies, label='Test Accuracy', color='green')
	axs[1].set_xlabel('Epoch')
	axs[1].set_ylabel('Accuracy (%)')
	axs[1].legend()
	
	plt.tight_layout()
	plt.draw()
	plt.pause(0.001)

这样就可以在训练的过程中产生一个交互式实时显示页面:
在这里插入图片描述

模型数据统计

一般统计模型信息包含模型的参数量和模型的FLOPs。模型的参数量可以反映模型的复杂度,较多的参数可能意味着模型可以拟合更复杂的函数,但也可能导致过拟合问题,进而反应模型的性能。FLOPs(浮点运算数)是衡量模型在推理阶段所需计算资源的重要指标,它反映了模型进行前向传播时所执行的浮点运算数量,通过分析模型的FLOPs,可以识别模型中的计算密集型部分。

在实例化模型后加入以下代码即可得到以上统计信息:

model = NeuralNetwork()
total_params = sum(p.numel() for p in net.parameters())
print("模型参数: ", total_params)
flops_count = 2 * sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"FLOPs: {flops_count:,}")

根据"基于pytorch的深度学习的MNIST手写数字图像分类"中的案例,运行上述代码后,其模型的相关统计信息为:

模型参数: 235146
FLOPs: 470,292

可视化测试结果

在训练完成后,可以使用测试数据对训练的模型进行测试,除了获得使用评价指标进行评估的统计数据之外,还可以输入一些图像使用模型测试获得可视化示例,这样可以更加直观地展示模型的测试效果。

首先载入模型:

model = NeuralNetwork() # 实例化网络
model.load_state_dict(torch.load('checkpoints/final_dict_model.pth')) # 载入模型
model.eval() # 开启评估模式

载入数据,并保持数据预处理的形式与训练时一致:

# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 载入测试集
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

选择待测试和可视化的图像:

# 从测试集中随机选择20张图像
selected_samples = random.sample(range(len(test_dataset)), 20)

评估测试图像:

for i, idx in enumerate(selected_samples):
    image, label = test_dataset[idx]
    image = image.reshape(-1, 28*28)
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output.data, 1)
    
    print("测试第{}张图像完毕!".format(i+1))

可视化结果:

for i, idx in enumerate(selected_samples):
	......
	# 测试评估代码
	......
    # 将单张图像添加到大图中
    row = i // 5
    col = i % 5
    axs[row, col].imshow(image[0].numpy().reshape(28, 28), cmap='gray')
    axs[row, col].set_title(f'Predicted: {predicted.item()}, True Label: {label}', fontsize=9)
    axs[row, col].axis('off')
plt.show()

最终的结果为:
请添加图片描述

关于我

欢迎与我交流!
公众号:AI小火车
Alt

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值