Below are three example models that achieve strong accuracy on current image classification tasks, based on Swin Transformer, EfficientNet, and ConvNeXt respectively. Each model comes with:
- Training code (PyTorch)
  - Fine-tunes from pretrained weights (you can also comment out the pretrained option and train from scratch)
  - Dataset directory structure:

    ```
    dataset_root
    ├── buy      # first class
    └── nobuy    # second class
    ```

  - Random split: 80% training, 20% validation
  - Prints the loss once per epoch
  - Early stopping: when the validation loss has not decreased for 10 consecutive epochs, training stops early and the model is saved
- Test code
  - Loads the saved model
  - Predicts on a single image or on every image in a folder, printing the probability of each of the two classes
The examples below depend on the following libraries:

- `torch`, `torchvision`
- `timm` (if you build Swin / EfficientNet / ConvNeXt via `timm.create_model`)
- or the corresponding models built into `torchvision.models` (such as `torchvision.models.convnext_tiny`, `torchvision.models.efficientnet_b0`).

If `timm` is not installed yet, run:

```bash
pip install timm
```
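If you are unsure which model names your installed `timm` version provides, a quick check like the following may help (a minimal sketch; the wildcard patterns are only examples):

```python
import timm

# List available model names matching a pattern; exact names can vary
# between timm versions, so it is worth checking before create_model.
print(timm.list_models("swin_tiny*"))        # e.g. ['swin_tiny_patch4_window7_224', ...]
print(timm.list_models("efficientnet_b0*"))
print(timm.list_models("convnext_tiny*"))
```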
I. Swin Transformer
1. Training code

```python
import os
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms
class EarlyStopping:
"""
当验证集上的 loss 在连续 patience 个 epoch 不再改善时触发停止。
"""
def __init__(self, patience=10, verbose=False):
self.patience = patience
self.counter = 0
self.best_loss = None
self.early_stop = False
self.verbose = verbose
def __call__(self, val_loss):
if self.best_loss is None:
self.best_loss = val_loss
elif val_loss < self.best_loss:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
if self.verbose:
print(f"EarlyStopping Counter: {self.counter} out of {self.patience}")
if self.counter >= self.patience:
self.early_stop = True
def train_swin_transformer(
data_root="dataset_root",
model_save_path="swin_transformer_best.pth",
batch_size=8,
num_epochs=50,
patience=10,
lr=1e-4
):
"""
使用 Swin Transformer 模型进行二分类训练示例。
:param data_root: 数据集根目录,下有 'buy', 'nobuy' 两个子文件夹
:param model_save_path: 训练完成后模型保存的路径
:param batch_size: 批量大小
:param num_epochs: 最大训练 epoch 数
:param patience: 早停的等待轮数,当验证损失连续 10 个 epoch 不下降则停止
:param lr: 学习率
"""
    # 1. Define the image transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Swin expects 224x224 input by default
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet pretraining normalization
    ])
    # 2. Load the data with ImageFolder
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
    class_names = dataset.classes  # ['buy', 'nobuy']; order follows the directory names
    # 3. Split into training (80%) and validation (20%) sets
dataset_size = len(dataset)
val_size = int(dataset_size * 0.2)
train_size = dataset_size - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    # 4. Create the DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    # 5. Build the Swin Transformer model with num_classes=2
model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=2)
    # 6. Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
# 7. Early Stopping
early_stopping = EarlyStopping(patience=patience, verbose=True)
best_val_loss = float("inf")
    # 8. Training loop
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
epoch_train_loss = running_loss / len(train_loader.dataset)
        # Evaluate on the validation set
model.eval()
val_running_loss = 0.0
with torch.no_grad():
for images, labels in val_loader:
outputs = model(images)
loss = criterion(outputs, labels)
val_running_loss += loss.item() * images.size(0)
epoch_val_loss = val_running_loss / len(val_loader.dataset)
print(f"Epoch [{epoch+1}/{num_epochs}], "
f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")
        # Save whenever the validation loss improves
if epoch_val_loss < best_val_loss:
best_val_loss = epoch_val_loss
torch.save(model.state_dict(), model_save_path)
print(" -> Model improved; saving current model.")
        # Early-stopping check
early_stopping(epoch_val_loss)
if early_stopping.early_stop:
print("Early stopping triggered!")
break
print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))
print(f"Model is saved to: {model_save_path}")
if __name__ == "__main__":
train_swin_transformer()
```

Notes on the code

- `timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=2)` builds the tiny Swin Transformer variant, loads pretrained weights, and replaces the output head with a 2-class layer.
- `ImageFolder` reads the `buy` and `nobuy` folders and infers the class labels automatically; `random_split` divides the data 80:20 into training and validation sets (a seeded, reproducible variant is sketched below).
- Each epoch prints the training and validation loss once; if the validation loss does not improve within 10 epochs, training stops.
- The best model weights are saved to `model_save_path`.
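One detail worth noting: `random_split` draws a fresh random split on every run. If you want the 80:20 split to be reproducible across runs (an optional tweak, not part of the script above), you can pass a seeded generator; `dataset`, `train_size`, and `val_size` below are the variables defined in the training function:

```python
import torch
from torch.utils.data import random_split

# A fixed seed makes the 80/20 split identical across runs, so that
# validation losses are comparable between experiments.
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=generator
)
```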
2. Test (inference) code

```python
import os
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
def predict_swin_transformer(
model_path="swin_transformer_best.pth",
target="test_image_or_dir",
class_names=("buy","nobuy")
):
"""
加载训练好的 Swin Transformer 模型,对单张图像或目录下所有图像进行预测。
:param model_path: 训练时保存的模型权重路径
:param target: 可以是单张图片路径,也可以是包含多张图片的目录路径
:param class_names: 类别名称,需与训练时顺序一致
"""
    # 1. Load the model
    model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=False, num_classes=2)
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
    # 2. Same image transforms as used in training
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def infer_image(img_path):
img = Image.open(img_path).convert("RGB")
input_tensor = transform(img).unsqueeze(0)
with torch.no_grad():
outputs = model(input_tensor)
probs = F.softmax(outputs, dim=1).numpy().flatten()
pred_idx = probs.argmax()
return pred_idx, probs
    # Is the target a single image or a directory?
    if os.path.isfile(target):
        # single image
idx, prob = infer_image(target)
print(f"Image: {target}")
print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
    elif os.path.isdir(target):
        # all images in the directory
for file_name in os.listdir(target):
file_path = os.path.join(target, file_name)
if os.path.isfile(file_path):
idx, prob = infer_image(file_path)
print(f"Image: {file_path}")
print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
else:
print("Error: target path is neither a file nor a directory.")
if __name__ == "__main__":
predict_swin_transformer(
model_path="swin_transformer_best.pth",
target="test_image_or_dir", # 改成想预测的路径
class_names=("buy", "nobuy")
)
```

Notes on prediction

- `model.eval()` switches the model to inference mode (and `torch.no_grad()` ensures no gradients are computed, so the parameters are not updated).
- For each image, the script prints the probability of each of the two classes and takes the index of the highest probability as the predicted class.
- `target` can be a specific file or a directory. The script runs on the CPU; a GPU-aware variant is sketched below.
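A GPU-aware variant of the inference helper might look like the following sketch (assumptions: a CUDA device may or may not be present, and the weights file is the one saved by the training script above):

```python
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

# Pick the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=False, num_classes=2)
model.load_state_dict(torch.load("swin_transformer_best.pth", map_location=device))
model = model.to(device).eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def infer_image(img_path):
    img = Image.open(img_path).convert("RGB")
    input_tensor = transform(img).unsqueeze(0).to(device)  # inputs must live on the same device as the model
    with torch.no_grad():
        outputs = model(input_tensor)
    # Tensors must be moved back to the CPU before calling .numpy().
    probs = F.softmax(outputs, dim=1).cpu().numpy().flatten()
    return probs.argmax(), probs
```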
II. EfficientNet
1. Training code
The following demonstrates fine-tuning the EfficientNet-B0 model from `timm` for binary classification. For the other variants (B1-B7), simply change the model name in `timm.create_model("efficientnet_b0")`; a caveat about input resolution is sketched below.
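One caveat when switching variants: the larger EfficientNets were pretrained at larger resolutions, so hard-coding 224x224 is suboptimal for them. A hedged sketch of deriving the transform from the model's own configuration via `timm` (here using `efficientnet_b4` purely as an example name):

```python
import timm
from timm.data import resolve_data_config, create_transform

# Build the variant, then read its default data configuration instead of
# hard-coding the input size.
model = timm.create_model("efficientnet_b4", pretrained=True, num_classes=2)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)  # eval-time transform matching the model
print(config["input_size"])             # e.g. (3, 380, 380) for B4
```

The full B0 training script follows.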
```python
import os
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
class EarlyStopping:
def __init__(self, patience=10, verbose=False):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_loss = None
self.early_stop = False
def __call__(self, val_loss):
if self.best_loss is None:
self.best_loss = val_loss
elif val_loss < self.best_loss:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
if self.verbose:
print(f"EarlyStopping Counter: {self.counter} / {self.patience}")
if self.counter >= self.patience:
self.early_stop = True
def train_efficientnet(
data_root="dataset_root",
model_save_path="efficientnet_best.pth",
batch_size=8,
num_epochs=50,
patience=10,
lr=1e-4
):
"""
使用 EfficientNet-B0 模型进行二分类训练示例。
"""
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
dataset = datasets.ImageFolder(root=data_root, transform=transform)
    # 80/20 split
dataset_size = len(dataset)
val_size = int(dataset_size * 0.2)
train_size = dataset_size - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    # Model: EfficientNet-B0 with a 2-class head
model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=2)
    # Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
early_stopping = EarlyStopping(patience=patience, verbose=True)
best_val_loss = float("inf")
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
epoch_train_loss = running_loss / len(train_loader.dataset)
        # Validation
model.eval()
val_running_loss = 0.0
with torch.no_grad():
for images, labels in val_loader:
outputs = model(images)
loss = criterion(outputs, labels)
val_running_loss += loss.item() * images.size(0)
epoch_val_loss = val_running_loss / len(val_loader.dataset)
print(f"Epoch [{epoch+1}/{num_epochs}], "
f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")
        # Save the best model
if epoch_val_loss < best_val_loss:
best_val_loss = epoch_val_loss
torch.save(model.state_dict(), model_save_path)
print(" -> Model improved; saving current model.")
        # Early-stopping check
early_stopping(epoch_val_loss)
if early_stopping.early_stop:
print("Early stopping triggered!")
break
print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))
print(f"Model is saved to: {model_save_path}")
if __name__ == "__main__":
train_efficientnet()
```

2. Test (inference) code

```python
import os
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
def predict_efficientnet(
model_path="efficientnet_best.pth",
target="test_image_or_dir",
class_names=("buy","nobuy")
):
    # 1. Build the model & load the weights
    model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=2)
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
    # 2. Preprocessing
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
def infer_image(img_path):
img = Image.open(img_path).convert("RGB")
input_tensor = transform(img).unsqueeze(0)
with torch.no_grad():
outputs = model(input_tensor)
probs = F.softmax(outputs, dim=1).numpy().flatten()
pred_idx = probs.argmax()
return pred_idx, probs
if os.path.isfile(target):
        # single image
idx, prob = infer_image(target)
print(f"Image: {target}")
print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
elif os.path.isdir(target):
        # all images in the directory
for file_name in os.listdir(target):
file_path = os.path.join(target, file_name)
if os.path.isfile(file_path):
idx, prob = infer_image(file_path)
print(f"Image: {file_path}")
print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
else:
print("Error: target path is neither a file nor a directory.")
if __name__ == "__main__":
predict_efficientnet(
model_path="efficientnet_best.pth",
target="test_image_or_dir",
class_names=("buy", "nobuy")
)
```

III. ConvNeXt
1. Training code

```python
import os
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
class EarlyStopping:
def __init__(self, patience=10, verbose=False):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_loss = None
self.early_stop = False
def __call__(self, val_loss):
if self.best_loss is None:
self.best_loss = val_loss
elif val_loss < self.best_loss:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
if self.verbose:
print(f"EarlyStopping Counter: {self.counter} / {self.patience}")
if self.counter >= self.patience:
self.early_stop = True
def train_convnext(
data_root="dataset_root",
model_save_path="convnext_best.pth",
batch_size=8,
num_epochs=50,
patience=10,
lr=1e-4
):
"""
使用 ConvNeXt Tiny 模型进行二分类训练示例。
"""
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
dataset = datasets.ImageFolder(root=data_root, transform=transform)
    # 80/20 split
dataset_size = len(dataset)
val_size = int(dataset_size * 0.2)
train_size = dataset_size - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    # convnext_tiny from timm with a 2-class head
model = timm.create_model("convnext_tiny", pretrained=True, num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
early_stopping = EarlyStopping(patience=patience, verbose=True)
best_val_loss = float("inf")
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
epoch_train_loss = running_loss / len(train_loader.dataset)
        # Validation
model.eval()
val_running_loss = 0.0
with torch.no_grad():
for images, labels in val_loader:
outputs = model(images)
loss = criterion(outputs, labels)
val_running_loss += loss.item() * images.size(0)
epoch_val_loss = val_running_loss / len(val_loader.dataset)
print(f"Epoch [{epoch+1}/{num_epochs}], "
f"Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")
if epoch_val_loss < best_val_loss:
best_val_loss = epoch_val_loss
torch.save(model.state_dict(), model_save_path)
print(" -> Model improved; saving current model.")
early_stopping(epoch_val_loss)
if early_stopping.early_stop:
print("Early stopping triggered!")
break
print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))
print(f"Model is saved to: {model_save_path}")
if __name__ == "__main__":
train_convnext()
```

2. Test (inference) code

```python
import os
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
def predict_convnext(
model_path="convnext_best.pth",
target="test_image_or_dir",
class_names=("buy","nobuy")
):
    # 1. Build the model & load the weights
    model = timm.create_model("convnext_tiny", pretrained=False, num_classes=2)
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
    # 2. Same preprocessing as in training
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
def infer_image(img_path):
img = Image.open(img_path).convert("RGB")
input_tensor = transform(img).unsqueeze(0)
with torch.no_grad():
outputs = model(input_tensor)
probs = F.softmax(outputs, dim=1).numpy().flatten()
pred_idx = probs.argmax()
return pred_idx, probs
if os.path.isfile(target):
idx, prob = infer_image(target)
print(f"Image: {target}")
print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
elif os.path.isdir(target):
for file_name in os.listdir(target):
file_path = os.path.join(target, file_name)
if os.path.isfile(file_path):
idx, prob = infer_image(file_path)
print(f"Image: {file_path}")
print(f" -> Predicted: {class_names[idx]}, Probability: {prob[idx]:.4f}")
else:
print("Error: target path is neither a file nor a directory.")
if __name__ == "__main__":
predict_convnext(
model_path="convnext_best.pth",
target="test_image_or_dir",
class_names=("buy", "nobuy")
)
```

Summary and notes

- Model choice: PyTorch examples are given for Swin Transformer (a vision Transformer), EfficientNet (an efficient CNN), and ConvNeXt (a modernized CNN). All three can reach good accuracy on small datasets by fine-tuning pretrained weights.
- Data preparation:
  - Place the images under `dataset_root/buy/` and `dataset_root/nobuy/`, one folder per class.
  - The `train_*()` functions read the data with `ImageFolder` and split it 80:20 into training and validation sets automatically.
- Early stopping: if the validation loss does not improve within `patience=10` epochs, the script stops training and keeps the best model saved so far.
- Saving and loading: after training, the best model (lowest validation loss) is saved to the given `.pth` file; the inference scripts load these weights for evaluation.
- Test scripts: support predictions on a single image or on every image in a directory, printing the two-class probabilities.
- Training from scratch: to train from scratch, set `pretrained=False` in `timm.create_model(...)`; however, with only thousands of images this is usually not recommended, since it tends to overfit and ends up less accurate than fine-tuning.
- Extensions: for more regularization or stronger data augmentation (random crops, random horizontal flips, AutoAugment, etc.), add the corresponding operations to the `transforms` pipeline to further improve generalization on small datasets; a sketch follows this list.
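As a concrete example of the augmentation mentioned above, a training-time transform might look like the following sketch (it assumes a torchvision version that ships `AutoAugment`, i.e. >= 0.10; keep the plain Resize/Normalize pipeline for validation and inference):

```python
from torchvision import transforms

# Augmented transform for the training set only; the exact choices and
# magnitudes should be tuned to your data.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),          # random 224x224 crop from the resized image
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```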
These examples let you train and test three strong binary image classifiers, covering the stated requirements: accuracy first, an 80% / 20% train / validation split, early stopping with best-model saving, and probability outputs at inference time. For further customization (multi-GPU training, learning-rate scheduling, mixed-precision training, etc.), you can build on this foundation; a mixed-precision sketch follows. Good luck with your experiments!
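For reference, a hedged sketch of what a mixed-precision training loop with a cosine learning-rate schedule could look like; it assumes a CUDA GPU and reuses `model`, `optimizer`, `criterion`, `train_loader`, and `num_epochs` as defined in the training scripts above:

```python
import torch

device = torch.device("cuda")
model = model.to(device)
scaler = torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():      # run the forward pass in fp16 where safe
            outputs = model(images)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()        # scale the loss to avoid fp16 underflow
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()                         # anneal the learning rate once per epoch
```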