EfficientNet-v2-s图像分类训练(简洁版)

使用torchvision集成的efficientnet-v2-s模型,调用torchvision库中的Oxford IIIT Pet数据集,对模型进行训练。
若有修改要求,可以修改以下部分:

train_dataset = OxfordIIITPet(root='./data', split='trainval', download=True, transform=transform_train)
test_dataset = OxfordIIITPet(root='./data', split='test', download=True, transform=transform_test)
# 常见数据集可以直接通过 torchvision.datasets 加载;如果使用自己的数据集,需要自行实现 Dataset,再用 DataLoader 封装
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 37)
# 37 是 Oxford-IIIT Pet 数据集的类别数,使用其他数据集时请改为对应的类别数
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.1, verbose=True)
# 学习率调度策略可以自行调整(例如 patience、factor):这里的 'max' 表示当监控指标(测试准确率)不再提升时降低学习率

训练截图:
(此处原为训练过程截图,图片未随文本保留)

实际上训练 10 个 epoch 左右,测试准确率就稳定在 90% 以上了,这里一共跑了 30 个 epoch。我是在 Kaggle 上运行的,所以模型保存在 /kaggle/working/ 目录下;在其他环境运行时,记得修改保存路径。
代码如下:

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_s
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import OxfordIIITPet
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

# Preprocessing and data augmentation.
# Per-channel statistics handed to Normalize in both pipelines (these match
# the values commonly used with ImageNet-pretrained torchvision models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize larger than the model input, then augment.
_train_steps = [
    transforms.Resize((256, 256)),  # upscale before cropping
    transforms.RandomCrop((224, 224)),  # random crop down to the model input size
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # colour jitter
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform_train = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic resize only, no augmentation.
_test_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform_test = transforms.Compose(_test_steps)

# Load the datasets.
# Both splits share the same root directory and download flag; only the
# split name and the transform pipeline differ.
_dataset_kwargs = dict(root='./data', download=True)
train_dataset = OxfordIIITPet(
    split='trainval',
    transform=transform_train,
    **_dataset_kwargs,
)
test_dataset = OxfordIIITPet(
    split='test',
    transform=transform_test,
    **_dataset_kwargs,
)

# Batches of 32 for both loaders; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# Model definition.
# Request the pretrained checkpoint through the `weights` argument: the legacy
# `pretrained=True` flag is deprecated in torchvision's multi-weight API and
# only maps onto these same IMAGENET1K_V1 weights while emitting a warning.
model = efficientnet_v2_s(weights='IMAGENET1K_V1')
# Replace the last classifier layer so the network outputs 37 classes
# (change 37 when training on a dataset with a different number of classes).
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 37)

# Device selection: use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss function, optimizer and learning-rate scheduler.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
# Mode 'max' because train_model() steps the scheduler with an accuracy value:
# the learning rate is multiplied by 0.1 once that value has stopped improving
# for more than `patience` epochs.
# NOTE(review): `verbose` appears to be deprecated in recent PyTorch releases —
# confirm against the installed version before upgrading.
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.1, verbose=True)

# Train the model
def train_model(num_epochs, save_path='/kaggle/working/best_oxford_pets_efficientnetv2.pth'):
    """Train the module-level ``model`` and checkpoint the best epoch.

    Each epoch runs one pass over ``train_loader``, prints the mean training
    loss, evaluates with ``test_model()``, feeds the resulting accuracy to
    ``scheduler`` and saves ``model.state_dict()`` whenever that accuracy
    exceeds the best value seen so far.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        save_path: Destination file for the best ``state_dict``. The default
            is the Kaggle working directory; pass another path when running
            elsewhere.

    Returns:
        The best accuracy (in percent) returned by ``test_model()`` during
        this run, or 0 if no epoch improved on the initial value.

    NOTE(review): checkpoint selection and LR scheduling both use the accuracy
    on ``test_loader``, so the reported best accuracy is not an unbiased
    held-out estimate.
    """
    best_accuracy = 0
    for epoch in range(num_epochs):
        # test_model() puts the model in eval mode, so training mode has to be
        # restored at the start of every epoch.
        model.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in tqdm(train_loader, desc=f'Training Epoch {epoch + 1}'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Report the mean batch loss; skipped when the loader yielded nothing
        # so an empty loader cannot trigger a division by zero.
        if num_batches:
            print(f'Epoch {epoch + 1}/{num_epochs} - Training Loss: {running_loss / num_batches:.4f}')

        # Evaluate after every epoch and let the scheduler react to the result
        accuracy = test_model()
        scheduler.step(accuracy)

        # Keep only the weights of the best-performing epoch
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), save_path)
            print(f'New best model saved with accuracy: {best_accuracy:.2f}%')

    return best_accuracy

def test_model():
    """Evaluate the module-level ``model`` on ``test_loader``.

    Puts the model in eval mode, runs every batch without gradient tracking
    and counts how often the highest-scoring class equals the label.

    Returns:
        Accuracy in percent (0-100). Returns 0.0 when ``test_loader`` yields
        no samples instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Testing'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Index of the largest logit per sample; no need for the legacy
            # `.data` attribute since gradients are already disabled here.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Debug output: while fewer than 50 samples have been counted,
            # print the first 10 predictions and labels of the current batch.
            if total < 50:
                print(f'Predicted: {predicted[:10]}, Labels: {labels[:10]}')

    if total == 0:
        # Nothing was evaluated; report it explicitly rather than dividing by zero.
        print('Testing Accuracy: n/a (test_loader yielded no samples)')
        return 0.0

    accuracy = 100 * correct / total
    print(f'Testing Accuracy: {accuracy:.2f}%')
    return accuracy



# Script entry point: run the full 30-epoch training schedule.
train_model(num_epochs=30)

  • 4
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值