[Deep Learning in Practice] A Code Walkthrough of Three of Today's Best Image Classification Models

Below are examples of three models with outstanding accuracy on current image classification tasks, based on Swin Transformer, EfficientNet, and ConvNeXt respectively. Each model comes with:

  1. Training code (PyTorch)

    • Fine-tunes from pretrained weights (you can also comment out the pretrained option and train from scratch)

    • Dataset directory structure:

      └── dataset_root
          ├── buy        # images of the first class
          └── nobuy      # images of the second class

    • Random split: 80% training, 20% validation

    • Prints the loss once per epoch

    • When the validation loss has not improved for 10 consecutive epochs, training stops early (early stopping) and the model is saved

  2. Test code

    • Loads the saved model

    • Predicts on a single image or on every image in a folder, outputting the probability of each of the two classes

The examples below depend on the following libraries:

  • torch, torchvision, and timm (if you use timm.create_model to build Swin / EfficientNet / ConvNeXt)

  • or the corresponding models built into torchvision.models (e.g. torchvision.models.convnext_tiny, torchvision.models.efficientnet_b0); a sketch of this alternative is shown after the install command below.

If timm is not installed yet, run:

pip install timm
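If you prefer the torchvision.models route mentioned above, the sketch below (assuming torchvision >= 0.13, which provides the weights= API) builds each architecture with pretrained weights and swaps the final classification layer for a fresh 2-class head:

import torch.nn as nn
from torchvision import models

swin = models.swin_t(weights=models.Swin_T_Weights.DEFAULT)
swin.head = nn.Linear(swin.head.in_features, 2)              # replace the 1000-class head

effnet = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
effnet.classifier[1] = nn.Linear(effnet.classifier[1].in_features, 2)

convnext = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
convnext.classifier[2] = nn.Linear(convnext.classifier[2].in_features, 2)

The rest of this article uses timm, but these models can be dropped into the same training loops.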

I. Swin Transformer

1. Training Code

import os
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms

class EarlyStopping:
    """
    Trigger early stopping when the validation loss has not improved for patience consecutive epochs.
    """
    def __init__(self, patience=10, verbose=False):
        self.patience = patience
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.verbose = verbose

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping Counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

def train_swin_transformer(
    data_root="dataset_root",
    model_save_path="swin_transformer_best.pth",
    batch_size=8,
    num_epochs=50,
    patience=10,
    lr=1e-4
):
    """
    Binary-classification training example using a Swin Transformer model.
    :param data_root: dataset root directory containing the 'buy' and 'nobuy' subfolders
    :param model_save_path: path where the model is saved after training
    :param batch_size: batch size
    :param num_epochs: maximum number of training epochs
    :param patience: early-stopping patience; training stops when the validation loss has not improved for this many consecutive epochs
    :param lr: learning rate
    """
    
    # 1. Define the image transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Swin's default resolution is 224
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet pretraining normalization
    ])

    # 2. Load the data with ImageFolder
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
    class_names = dataset.classes  # ['buy', 'nobuy']; the order follows the directory names

    # 3. Split into training (80%) and validation (20%) sets
    dataset_size = len(dataset)
    val_size = int(dataset_size * 0.2)
    train_size = dataset_size - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # 4. Create the DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # 5. Build the Swin Transformer model with num_classes=2
    model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=2)

    # 6. Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # 7. Early Stopping
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    best_val_loss = float("inf")

    # 8. Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        epoch_train_loss = running_loss / len(train_loader.dataset)

        # Evaluate on the validation set
        model.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item() * images.size(0)

        epoch_val_loss = val_running_loss / len(val_loader.dataset)

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")

        # Save the model if this is the best validation loss so far
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), model_save_path)
            print("  -> Model improved; saving current model.")

        # Early-stopping check
        early_stopping(epoch_val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered!")
            break

    print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))
    print(f"Model is saved to: {model_save_path}")

if __name__ == "__main__":
    train_swin_transformer()
Code Notes
  • timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=2) builds a Swin Transformer Tiny model, loads the pretrained weights, and replaces the output layer with a 2-class head.

  • ImageFolder reads the data from the buy and nobuy folders and assigns class labels automatically. random_split then splits it 80:20 into training and validation sets.

  • Each epoch prints the training loss and the validation loss; if the validation loss does not improve for 10 consecutive epochs, training stops.

  • The best model weights are saved to the model_save_path file.
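The training loops in this article run on the CPU because neither the model nor the batches are moved to a GPU. A minimal sketch of how device handling could be added (assuming a CUDA-capable machine; the helper name train_one_epoch is only for illustration):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_one_epoch(model, loader, criterion, optimizer):
    """Run one training epoch with the model and every batch on `device`."""
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)  # move the batch
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    return running_loss / len(loader.dataset)

# Usage inside train_swin_transformer: call model = model.to(device) after building
# the model, then call train_one_epoch(model, train_loader, criterion, optimizer)
# once per epoch. If training ran on a GPU, load the weights for CPU inference with
# torch.load(model_path, map_location="cpu").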


2. Testing (Inference) Code

import os
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

def predict_swin_transformer(
    model_path="swin_transformer_best.pth",
    target="test_image_or_dir",
    class_names=("buy","nobuy")
):
    """
    Load the trained Swin Transformer model and predict on a single image or on every image in a directory.
    :param model_path: path to the model weights saved during training
    :param target: either a single image path or a directory containing multiple images
    :param class_names: class names; the order must match the one used during training
    """
    # 1. Load the model
    model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=False, num_classes=2)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # 2. Define the same image transforms as used during training
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    def infer_image(img_path):
        img = Image.open(img_path).convert("RGB")
        input_tensor = transform(img).unsqueeze(0)
        with torch.no_grad():
            outputs = model(input_tensor)
            probs = F.softmax(outputs, dim=1).numpy().flatten()
        pred_idx = probs.argmax()
        return pred_idx, probs

    # Determine whether target is a single image or a directory
    if os.path.isfile(target):
        # Single image
        idx, prob = infer_image(target)
        print(f"Image: {target}")
        print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
    elif os.path.isdir(target):
        # All images in the directory
        for file_name in os.listdir(target):
            file_path = os.path.join(target, file_name)
            if os.path.isfile(file_path):
                idx, prob = infer_image(file_path)
                print(f"Image: {file_path}")
                print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
    else:
        print("Error: target path is neither a file nor a directory.")

if __name__ == "__main__":
    predict_swin_transformer(
        model_path="swin_transformer_best.pth",
        target="test_image_or_dir",  # 改成想预测的路径
        class_names=("buy", "nobuy")
    )
Prediction Notes
  • model.eval() switches layers such as dropout and batch norm to inference mode, and torch.no_grad() disables gradient tracking, so no parameters are updated.

  • For each image, the softmax probabilities of the two classes are computed, and the predicted class is the one with the largest probability.

  • target can be either a single file or a folder.
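The script above prints only the probability of the predicted class. If you also want the probability of the other class, a small helper such as the one below (the name format_probs is just for illustration) can print the full softmax vector returned by infer_image:

def format_probs(probs, class_names=("buy", "nobuy")):
    """Format the probability of every class as a readable string."""
    return ", ".join(f"{name}: {p:.4f}" for name, p in zip(class_names, probs))

# Example with the values returned by infer_image:
# idx, prob = infer_image(target)
# print(f" -> {format_probs(prob)}")   # e.g. buy: 0.9312, nobuy: 0.0688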


II. EfficientNet

1. Training Code

The following demonstrates a binary-classification training pipeline using the EfficientNet-B0 model from timm. For other variants (B1–B7), simply change the model name in timm.create_model("efficientnet_b0"); a sketch of this follows below.
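Note that larger EfficientNet variants were trained at resolutions above 224, so hard-coding Resize((224, 224)) is not ideal for them. A minimal sketch (assuming a recent timm release, which exposes resolve_data_config and create_transform) of deriving the matching preprocessing from the pretrained configuration:

import timm
from timm.data import resolve_data_config, create_transform

model_name = "efficientnet_b3"  # any of b0-b7; b3 is only an illustrative choice
model = timm.create_model(model_name, pretrained=True, num_classes=2)

# Derive the input size, interpolation, and normalization that match the chosen weights.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
print(config["input_size"])  # larger variants expect inputs bigger than 224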

import os
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class EarlyStopping:
    def __init__(self, patience=10, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping Counter: {self.counter} / {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

def train_efficientnet(
    data_root="dataset_root",
    model_save_path="efficientnet_best.pth",
    batch_size=8,
    num_epochs=50,
    patience=10,
    lr=1e-4
):
    """
    Binary-classification training example using an EfficientNet-B0 model.
    """
    # Data preprocessing
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageFolder(root=data_root, transform=transform)

    # 80/20 split
    dataset_size = len(dataset)
    val_size = int(dataset_size * 0.2)
    train_size = dataset_size - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Model: EfficientNet-B0 with 2 classes
    model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=2)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        epoch_train_loss = running_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item() * images.size(0)

        epoch_val_loss = val_running_loss / len(val_loader.dataset)

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")

        # Save the best model
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), model_save_path)
            print("  -> Model improved; saving current model.")

        # Early-stopping check
        early_stopping(epoch_val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered!")
            break

    print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))
    print(f"Model is saved to: {model_save_path}")

if __name__ == "__main__":
    train_efficientnet()

2. Testing (Inference) Code

import os
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

def predict_efficientnet(
    model_path="efficientnet_best.pth",
    target="test_image_or_dir",
    class_names=("buy","nobuy")
):
    # 1. Create the model & load the weights
    model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=2)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # 2. Preprocessing
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    def infer_image(img_path):
        img = Image.open(img_path).convert("RGB")
        input_tensor = transform(img).unsqueeze(0)
        with torch.no_grad():
            outputs = model(input_tensor)
            probs = F.softmax(outputs, dim=1).numpy().flatten()
        pred_idx = probs.argmax()
        return pred_idx, probs

    if os.path.isfile(target):
        # Single image
        idx, prob = infer_image(target)
        print(f"Image: {target}")
        print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
    elif os.path.isdir(target):
        # Directory
        for file_name in os.listdir(target):
            file_path = os.path.join(target, file_name)
            if os.path.isfile(file_path):
                idx, prob = infer_image(file_path)
                print(f"Image: {file_path}")
                print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
    else:
        print("Error: target path is neither a file nor a directory.")

if __name__ == "__main__":
    predict_efficientnet(
        model_path="efficientnet_best.pth",
        target="test_image_or_dir",
        class_names=("buy", "nobuy")
    )

III. ConvNeXt

1. Training Code

import os
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class EarlyStopping:
    def __init__(self, patience=10, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping Counter: {self.counter} / {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

def train_convnext(
    data_root="dataset_root",
    model_save_path="convnext_best.pth",
    batch_size=8,
    num_epochs=50,
    patience=10,
    lr=1e-4
):
    """
    Binary-classification training example using a ConvNeXt Tiny model.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageFolder(root=data_root, transform=transform)

    # 80/20 split
    dataset_size = len(dataset)
    val_size = int(dataset_size * 0.2)
    train_size = dataset_size - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Use convnext_tiny from timm with 2 classes
    model = timm.create_model("convnext_tiny", pretrained=True, num_classes=2)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        epoch_train_loss = running_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item() * images.size(0)

        epoch_val_loss = val_running_loss / len(val_loader.dataset)

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), model_save_path)
            print("  -> Model improved; saving current model.")

        early_stopping(epoch_val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered!")
            break

    print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))
    print(f"Model is saved to: {model_save_path}")

if __name__ == "__main__":
    train_convnext()

2. Testing (Inference) Code

import os
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

def predict_convnext(
    model_path="convnext_best.pth",
    target="test_image_or_dir",
    class_names=("buy","nobuy")
):
    # 1. Create the model & load the weights
    model = timm.create_model("convnext_tiny", pretrained=False, num_classes=2)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # 2. The same preprocessing as used during training
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    def infer_image(img_path):
        img = Image.open(img_path).convert("RGB")
        input_tensor = transform(img).unsqueeze(0)
        with torch.no_grad():
            outputs = model(input_tensor)
            probs = F.softmax(outputs, dim=1).numpy().flatten()
        pred_idx = probs.argmax()
        return pred_idx, probs

    if os.path.isfile(target):
        idx, prob = infer_image(target)
        print(f"Image: {target}")
        print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
    elif os.path.isdir(target):
        for file_name in os.listdir(target):
            file_path = os.path.join(target, file_name)
            if os.path.isfile(file_path):
                idx, prob = infer_image(file_path)
                print(f"Image: {file_path}")
                print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
    else:
        print("Error: target path is neither a file nor a directory.")

if __name__ == "__main__":
    predict_convnext(
        model_path="convnext_best.pth",
        target="test_image_or_dir",
        class_names=("buy", "nobuy")
    )

Summary and Notes

  1. Model choice: PyTorch examples are given for Swin Transformer (a vision-Transformer architecture), EfficientNet (an efficient convolutional network), and ConvNeXt (a modernized convolutional network). All of them can reach good accuracy on small datasets by fine-tuning pretrained weights.

  2. Data preparation

    • Put the images in dataset_root/buy/ and dataset_root/nobuy/, one folder for each of the two classes.

    • The train_*() functions read them with ImageFolder and automatically perform the 80:20 train/validation split.

  3. Early stopping: if the validation loss does not improve within patience=10 epochs, the script stops training; the checkpoint kept on disk is the best model seen so far (lowest validation loss), not the weights from the final epoch.

  4. Saving and loading: after training, the best model (by lowest validation loss) is saved to the specified .pth file. The inference scripts load these weights for evaluation.

  5. Test scripts: support prediction on a single image or batch prediction on all images in a directory, outputting the two-class probabilities.

  6. Training from scratch: set pretrained=False in timm.create_model(..., pretrained=False, ...) to train without pretrained weights. With only a few thousand images this is usually not recommended, since it tends to overfit and the accuracy falls short of fine-tuning.

  7. Extensions: to add more regularization or data augmentation (random cropping, random horizontal flipping, AutoAugment, etc.), add the corresponding operations to the transforms to further improve generalization on small datasets; a minimal sketch follows this list.
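A lightly augmented training transform might look like the sketch below (the specific operations and parameters are only illustrative; keep the plain resize + normalize pipeline for the validation split so that evaluation stays deterministic):

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),   # random crop, resized back to 224
    transforms.RandomHorizontalFlip(),                      # random horizontal flip
    transforms.ColorJitter(brightness=0.2, contrast=0.2),   # mild color jitter
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Pass train_transform when building the training ImageFolder; keep the original
# transform for the validation images.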

The examples above let you train and test three strong binary image classifiers, covering the requirements of accuracy first, an 80% train / 20% validation split, early stopping with model saving, and probability outputs at inference time. For further customization (multi-GPU training, learning-rate scheduling, mixed-precision training, and so on), you can extend them from here; one small example follows. Good luck with your experiments!
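As one example of these extensions, a cosine learning-rate schedule can be dropped into any of the training functions above. The sketch below uses a toy model only so that it runs standalone; in the real scripts you would reuse the existing model, optimizer, and num_epochs:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                                     # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 50
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    # ... run the usual training and validation passes here ...
    scheduler.step()  # decay the learning rate once per epoch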
