I. Project Background and Results
Flower classification is a classic computer vision task. This post uses the PyTorch framework and a ResNet18 backbone to classify 102 flower species. The complete code can be copied and run as-is, and the article also analyzes performance bottlenecks and optimization options along the way.
II. Environment Setup and Data Preparation
1. Environment Requirements
# Main dependencies
import torch
from torch import nn, optim
from torch.utils.data import DataLoader  # needed when building the data loaders below
from torchvision import transforms, datasets, models
import matplotlib.pyplot as plt
import numpy as np
import json
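A quick sanity check like the one below (a minimal sketch, nothing project-specific assumed) confirms the installed versions and whether PyTorch sees a GPU before you start training:
# Optional environment check
import torch, torchvision
print('torch:', torch.__version__, '| torchvision:', torchvision.__version__)
print('CUDA available:', torch.cuda.is_available())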
2. Dataset Structure
data/
├── train/                # training set (6,489 images)
│   ├── 1/                # one folder per class ID
│   ├── 2/
│   └── ...
├── valid/                # validation set (1,700 images)
└── flower_names.json     # class-name mapping file
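To confirm the folders really contain the expected 6,489 / 1,700 images, a short count over the directory tree is enough (a minimal sketch; the data_root path is a placeholder you should adjust):
# Count images per split
from pathlib import Path
data_root = Path('data')
for split in ['train', 'valid']:
    n = sum(1 for p in (data_root / split).rglob('*') if p.is_file())
    print(f'{split}: {n} images')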
3. Class Mapping File
# Example flower_names.json
{
    "1": "rose",
    "2": "tulip",
    ...,
    "102": "wall moss"
}
III. Complete Code Implementation
1. Data Loading and Augmentation
# Define data paths
data_dir = 'D:/python_text/python/花卉识别/data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

# Data augmentation and preprocessing
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),
        transforms.CenterCrop(64),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# Create datasets
batch_size = 512
train_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['train'])
valid_dataset = datasets.ImageFolder(valid_dir, transform=data_transforms['valid'])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
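matplotlib and numpy are imported above but not yet used; one handy use is to visualize a batch after augmentation, undoing the Normalize transform before plotting (a minimal sketch based on the loaders defined above):
# Visualize a few augmented training samples
images, labels = next(iter(train_loader))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for ax, img, label in zip(axes, images, labels):
    img = img.numpy().transpose(1, 2, 0) * std + mean  # CHW -> HWC, de-normalize
    ax.imshow(np.clip(img, 0, 1))
    ax.set_title(train_dataset.classes[label.item()])  # folder name, not flower name
    ax.axis('off')
plt.show()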
2. Model Definition (Modified ResNet18)
class ResNet18Model(nn.Module):
    def __init__(self, num_classes=102, pretrained=True):
        super().__init__()
        # Note: torchvision >= 0.13 deprecates `pretrained=` in favor of `weights=`
        self.base_model = models.resnet18(pretrained=pretrained)
        num_features = self.base_model.fc.in_features
        # Replace the original fully connected layer with a small classification head
        self.base_model.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.base_model(x)

# Initialize the model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ResNet18Model().to(device)
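Since the backbone is pretrained, a common warm-start (not part of the training run below, just a sketch) is to freeze everything except the new head for the first few epochs and check how many parameters remain trainable:
# Optional: freeze the pretrained backbone, train only the new fc head
for name, param in model.base_model.named_parameters():
    if not name.startswith('fc'):
        param.requires_grad = False
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {trainable:,}')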
3. Training Configuration
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)  # set the learning rate explicitly
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Load class names (UTF-8 in case the file contains non-ASCII names)
with open('D:/python_text/python/花卉识别/data/flower_names.json', 'r', encoding='utf-8') as f:
    flower_names = json.load(f)
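One pitfall worth calling out: ImageFolder assigns class indices by sorting the folder names as strings ('1', '10', '100', ...), so a predicted index is not the folder number itself. The sketch below (the helper name pred_to_flower_name is made up for illustration) maps an index back to a readable name:
# Map a predicted class index -> folder name -> flower name
idx_to_folder = train_dataset.classes  # e.g. index 1 corresponds to folder '10'
def pred_to_flower_name(pred_idx):
    return flower_names[idx_to_folder[pred_idx]]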
4. Training and Validation Loop
num_epochs = 25
for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)

    # Validation phase
    model.eval()
    valid_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            valid_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()

    # Step the LR scheduler once per epoch (otherwise StepLR never takes effect)
    scheduler.step()

    # Print epoch statistics
    train_loss = train_loss / len(train_dataset)
    valid_loss = valid_loss / len(valid_dataset)
    valid_acc = correct / len(valid_dataset)
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f} | Val Loss: {valid_loss:.4f} | Val Acc: {valid_acc:.4f}')
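It is also worth keeping the best checkpoint by validation accuracy; the lines below are a sketch of what you would add at the end of each epoch (initialize best_acc = 0.0 before the loop; the filename is arbitrary):
# Save the best model by validation accuracy (inside the epoch loop)
if valid_acc > best_acc:
    best_acc = valid_acc
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'best_acc': best_acc}, 'best_resnet18_flowers.pth')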
IV. Optimization Plan
1. Code Improvement Suggestions
# Example tweaks: smaller batch size + learning-rate warmup
batch_size = 256  # was 512
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)  # adds decoupled weight decay (L2-style regularization)
# Add learning-rate warmup
from torch.optim.lr_scheduler import LinearLR
warmup_scheduler = LinearLR(optimizer, start_factor=0.01, total_iters=5)
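To actually run warmup followed by the original step decay, the two schedules can be chained with SequentialLR (available in PyTorch 1.10+); this is one possible wiring, not the only one:
# Chain warmup with the existing StepLR decay
from torch.optim.lr_scheduler import SequentialLR, StepLR
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, StepLR(optimizer, step_size=10, gamma=0.1)],
    milestones=[5])  # switch from warmup to step decay after epoch 5
# call scheduler.step() once per epoch in the training loop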
2. Other Optimization Directions
- Data level:
  - Use public datasets (e.g. Oxford-102 Flowers) to expand the training data
  - Add flower-specific augmentations (local petal crops, lighting simulation)
- Model level:
  - Switch to ResNet50 or a ViT backbone
  - Add attention modules
- Training tricks:
  - Use label smoothing
  - Use Focal Loss to handle class imbalance (see the sketch after this list)
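For the training tricks above, label smoothing is built into CrossEntropyLoss since PyTorch 1.10, and a bare-bones focal loss takes only a few lines; the gamma value below is a common default, not something tuned for this dataset:
# Label smoothing via the built-in loss
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Minimal focal-loss sketch for class imbalance
import torch.nn.functional as F
def focal_loss(logits, targets, gamma=2.0):
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)  # probability assigned to the true class
    return ((1 - pt) ** gamma * ce).mean()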
V. Summary
This post provided complete, directly runnable code for flower classification and proposed improvement directions for the low-accuracy problem. Key points:
- Data augmentation should match flower characteristics (avoid excessive rotation)
- Set hyperparameters sensibly (batch size, learning rate)
- For complex scenarios, consider more advanced model architectures
Copy the code locally, change the dataset path, and run. It is also worth adding a Grad-CAM visualization module to analyze what drives the model's decisions.
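As a starting point for that suggestion, here is a minimal Grad-CAM sketch using forward/backward hooks on layer4 of the backbone; it assumes img is a single image tensor that has already gone through the validation transform:
# Minimal Grad-CAM sketch for the last conv stage of ResNet18
feats, grads = {}, {}

def fwd_hook(module, inp, out):
    feats['value'] = out.detach()

def bwd_hook(module, grad_in, grad_out):
    grads['value'] = grad_out[0].detach()

h1 = model.base_model.layer4.register_forward_hook(fwd_hook)
h2 = model.base_model.layer4.register_full_backward_hook(bwd_hook)

model.eval()
x = img.unsqueeze(0).to(device)        # [1, 3, 64, 64]
logits = model(x)
logits[0, logits.argmax()].backward()  # gradient of the top-scoring class

weights = grads['value'].mean(dim=(2, 3), keepdim=True)        # GAP over H, W
cam = torch.relu((weights * feats['value']).sum(dim=1)).squeeze()
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)       # normalize to [0, 1]
plt.imshow(cam.cpu().numpy(), cmap='jet'); plt.axis('off'); plt.show()

h1.remove(); h2.remove()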