基于PyTorch和ResNet18的花卉识别实战(附完整代码)

一、项目背景与效果

花卉分类是计算机视觉的经典任务。本文使用PyTorch框架,基于ResNet18模型实现了102种花卉的分类任务。完整代码可直接复制运行,文中同步分析性能瓶颈与优化方案。

二、环境配置与数据准备

1. 环境要求

# 主要依赖库
import torch
from torch import nn, optim
from torchvision import transforms, datasets, models
import matplotlib.pyplot as plt
import numpy as np
import json

2. 数据集结构

data/
├── train/         # 训练集(6489张)
│   ├── 1/        # 类别编号文件夹
│   ├── 2/
│   └── ... 
├── valid/         # 验证集(1700张)
└── flower_names.json  # 类别映射文件

3. 类别映射文件

# flower_names.json 示例
{
  "1": "玫瑰",
  "2": "郁金香",
  ...,
  "102": "墙藓"
}

三、完整代码实现

1. 数据加载与增强

# 定义数据路径
data_dir = 'D:/python_text/python/花卉识别/data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

# 数据增强与预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),
        transforms.CenterCrop(64),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# 创建数据集
batch_size = 512
train_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['train'])
valid_dataset = datasets.ImageFolder(valid_dir, transform=data_transforms['valid'])

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

2. 模型定义(修改ResNet18)

class ResNet18Model(nn.Module):
    def __init__(self, num_classes=102, pretrained=True):
        super().__init__()
        self.base_model = models.resnet18(pretrained=pretrained)
        num_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        return self.base_model(x)

# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ResNet18Model().to(device)

3. 训练配置

# 定义损失函数与优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 需显式设置学习率
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# 加载类别名称
with open('D:/python_text/python/花卉识别/data/flower_names.json', 'r') as f:
    flower_names = json.load(f)

4. 训练与验证循环

num_epochs = 25
for epoch in range(num_epochs):
    # 训练阶段
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * images.size(0)
    
    # 验证阶段
    model.eval()
    valid_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            valid_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
    
    # 打印统计信息
    train_loss = train_loss / len(train_dataset)
    valid_loss = valid_loss / len(valid_dataset)
    valid_acc = correct / len(valid_dataset)
    
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f} | Val Loss: {valid_loss:.4f} | Val Acc: {valid_acc:.4f}')

四、优化方案

1. 代码改进建议

# 优化点示例:减小批次大小 + 学习率预热
batch_size = 256  # 原512
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)  # 添加L2正则

# 添加学习率预热
from torch.optim.lr_scheduler import LinearLR
warmup_scheduler = LinearLR(optimizer, start_factor=0.01, total_iters=5)

2. 其他优化方向

  • 数据层面

    • 使用公开数据集(如Oxford-102 Flowers)扩充数据

    • 添加针对性数据增强(花瓣局部裁剪、光照模拟)

  • 模型层面

    • 更换为ResNet50/ViT模型

    • 添加注意力机制模块

  • 训练技巧

    • 使用标签平滑(Label Smoothing)

    • 引入Focal Loss解决类别不平衡

五、总结

本文提供了可直接运行的花卉分类完整代码,并针对低准确率问题提出了改进方向。关键点:

  1. 数据增强需符合花卉特征(避免过度旋转)

  2. 合理设置超参数(批次大小、学习率)

  3. 复杂场景建议使用更先进的模型架构

代码可直接复制到本地,修改数据集路径后运行。建议尝试添加Grad-CAM可视化模块,深入分析模型决策依据。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值