一、项目介绍
项目概述
这是一个使用PyTorch框架的深度学习项目,目的是构建一个能够区分蜜蜂和蚂蚁的图像分类模型。项目使用了VGG16作为基础模型,并采用了迁移学习的方法来提高模型的性能。
技术栈
- PyTorch: 主要使用的深度学习框架。
- Torchvision: 提供了用于图像处理的常用工具,包括数据加载和转换。
- Matplotlib: 用于绘制训练过程中的损失和准确率曲线。
- tqdm: 进度条显示库,用于监控训练进度。
数据准备
- 数据集路径:
- 训练数据集路径:
I:\code\pytorch\VGG\datasets\train
- 验证数据集路径:
I:\code\pytorch\VGG\datasets\val
- 训练数据集路径:
- 数据增强:
- 将所有图像调整为224x224像素。
- 转换为张量。
- 使用ImageNet的均值和标准差进行标准化。
模型定义
- VGG16:
- 类别数 (
num_classes
): 2(蜜蜂和蚂蚁)。 - 预训练 (
pretrained
): 可选择使用或不使用ImageNet上的预训练权重。 - 模型实例化后被移动到指定的设备(CPU或GPU)。
- 类别数 (
训练配置
- 超参数:
- 批大小 (
batch_size
): 4。 - 学习率 (
learning_rate
): 0.001。 - 训练轮数 (
num_epochs
): 30。
- 批大小 (
- 优化器:
- 使用随机梯度下降 (SGD) 优化器。
- 动量设置为0.9。
- 学习率调度器:
- 使用StepLR策略,在每7个epoch后将学习率乘以0.1。
训练过程
- 训练循环:
- 在每个epoch中,模型首先在训练集上进行训练,然后在验证集上进行评估。
- 使用交叉熵损失 (
CrossEntropyLoss
) 作为损失函数。 - 在训练过程中记录训练和验证损失及准确率。
- 每个epoch结束时,保存当前模型状态。
- 可视化:
- 训练完成后,绘制训练和验证的损失、准确率以及学习率的变化曲线,并保存到指定文件夹。
文件结构
- 输出文件夹 (
output_dir
): 保存训练好的模型、损失/准确率曲线图和学习率曲线图。
执行
- 主函数 (
main
):- 接受一个布尔参数
use_pretrained
来决定是否使用预训练权重。 - 调用
train_model
函数进行模型训练。
- 接受一个布尔参数
注意事项
- 确保指定的路径正确无误。
- 如果使用GPU,确保安装了正确的PyTorch版本以及CUDA驱动。
- 可能需要根据实际情况调整超参数以获得最佳性能。
二、数据集介绍
数据集之前介绍过,在此不再赘述。
三、完整代码
训练代码:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotl