1、摘要
1.1、声明
*本博客重实际操作、轻理论,按照流程即可完成自己需要实现的图像多分类问题,不再赘述官方模型及相关文献。
官方demo:
https://huggingface.co/spaces/pytorch/ResNet
1.2、ResNet18/34
ResNet18 和 ResNet34 是深度残差网络(Residual Network)的两种较浅变体,分别包含 18 层和 34 层带权重的层(卷积层加最后的全连接层)。二者均由 BasicBlock(两个 3×3 卷积加跳跃连接)堆叠而成;瓶颈(Bottleneck)结构只用于 ResNet50 及更深的版本。通过引入跳跃连接,残差网络缓解了深层网络训练中的退化问题,并在图像分类、目标检测、图像分割等多个领域得到了广泛应用。
1.3、torch/torchvision
torch 提供张量计算与训练框架;torchvision 提供预训练模型、数据集加载和图像变换,本文的模型构建与数据预处理均基于这两个库。
*理论和原理建议参照如下:
resnet34文档:
http://pytorch.org/vision/main/models/generated/torchvision.models.resnet34.html
torchvision文档:
https://pytorch.org/vision/stable/index.html
2、数据准备
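所有数据依据自己需要的类别进行分类,类别名称可以自定义。训练代码使用 ImageFolder 读取数据,因此目录需要按"数据集划分/类别名/图片"的结构组织,例如 data/train/indoor/xxx.jpg、data/val/indoor/xxx.jpg;train 和 val 下的类别文件夹名称必须一致。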
3、训练
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2024/7/8
# @Author : CCM
# @Describe : 图像多分类训练代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torchvision.models import ResNet34_Weights
from PIL import Image
import os
# 检查数据集中的文件
def check_dataset(data_dir_):
    """Scan ``data_dir_`` recursively and report image files that fail verification.

    Every file whose name ends in .jpg, .png or .jpeg (case-insensitive) is
    opened with PIL and passed through ``Image.verify()``. A message is
    printed for each file that cannot be opened or verified.

    Args:
        data_dir_: Root directory of the dataset to scan.

    Returns:
        A list with the paths of all files that failed verification
        (empty when none failed).
    """
    corrupted = []
    for root, _dirs, files in os.walk(data_dir_):
        for file in files:
            # Lower-case the name so that e.g. "IMG_01.JPG" is checked too.
            if not file.lower().endswith(('.jpg', '.png', '.jpeg')):
                continue
            path = os.path.join(root, file)
            try:
                # The context manager closes the file handle even when
                # verify() raises, so scanning a large dataset does not
                # leak file descriptors.
                with Image.open(path) as img:
                    img.verify()  # checks file integrity
            except (OSError, SyntaxError):
                print(f"文件损坏: {path}")
                corrupted.append(path)
    return corrupted
data_dir = 'data'
check_dataset(data_dir)

# Preprocessing pipelines, one per dataset split.
# Training: random 224x224 crop and random horizontal flip as augmentation.
train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Validation: deterministic resize to 256 followed by a 224x224 center crop.
val_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
data_transforms = {
    'train': train_pipeline,
    'val': val_pipeline,
}
# Build the dataset, loader and sample count for each split in one pass.
image_datasets = {}
dataloaders = {}
dataset_sizes = {}
for split in ['train', 'val']:
    split_dataset = datasets.ImageFolder(
        os.path.join(data_dir, split),
        data_transforms[split],
    )
    image_datasets[split] = split_dataset
    # num_workers=0: batches are loaded in the main process.
    dataloaders[split] = torch.utils.data.DataLoader(
        split_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=0,
    )
    dataset_sizes[split] = len(split_dataset)

class_names = image_datasets['train'].classes

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Report the number of training and validation samples.
print(f"训练数据量: {dataset_sizes['train']}")
print(f"验证数据量: {dataset_sizes['val']}")
# Model: ResNet-34 with pretrained weights and a new classification head.
model_ft = models.resnet34(weights=ResNet34_Weights.DEFAULT)
num_frs = model_ft.fc.in_features
# Size the head from the classes actually found in the training folder
# instead of hard-coding 3, so the script keeps working when the number of
# class directories changes.
model_ft.fc = nn.Linear(num_frs, len(class_names))
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
# Optimizer: SGD with momentum over all model parameters.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Learning-rate schedule: multiply the rate by 0.1 every 7 scheduler steps.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# 训练模型
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Fine-tune ``model`` and return it with the best validation weights loaded.

    The published listing replaced the loop body with a placeholder string,
    so the function returned the model untrained. This body is the standard
    train/validate loop for fine-tuning a classifier. It reads the
    module-level ``dataloaders``, ``dataset_sizes`` and ``device``.

    Args:
        model: Network to train; expected to already be on ``device``.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of passes over the training split.

    Returns:
        ``model`` with the weights of the epoch that reached the highest
        validation accuracy.
    """
    import copy  # local import keeps this block self-contained

    best_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 40)

        for phase in ('train', 'val'):
            # train() enables dropout/batch-norm updates; eval() freezes them.
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                # Gradients are only tracked while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # loss is the batch mean, so weight it by the batch size.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Keep a copy of the best-performing weights seen on validation.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())

    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_weights)
    return model
# Run train_model for 25 epochs with the loss, optimizer and scheduler
# configured earlier in this script.
model_ft = train_model(
    model_ft,
    criterion,
    optimizer_ft,
    exp_lr_scheduler,
    num_epochs=25,
)
# Save only the parameters (state_dict), not the whole model object.
trained_state = model_ft.state_dict()
torch.save(trained_state, 'resnet34_indoor_ccm.pth')
4、推理
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2024/7/8
# @Author : CCM
# @Describe : 使用训练模型进行推理,此为三分类
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchvision.models import ResNet34_Weights
from PIL import Image
# Inference preprocessing: resize to 256, center-crop to 224, convert to a
# tensor, then normalize each channel.
inference_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
data_transforms = transforms.Compose(inference_steps)
# Build the bare ResNet-34 architecture. weights=None skips downloading the
# ImageNet checkpoint, which the fine-tuned state_dict loaded below would
# overwrite anyway.
model = models.resnet34(weights=None)
num_frs = model.fc.in_features
model.fc = nn.Linear(num_frs, 3)  # three output classes, matching the checkpoint
# Load the fine-tuned weights. map_location='cpu' lets a checkpoint that was
# saved from a GPU run load on a machine without CUDA.
model.load_state_dict(torch.load('resnet34_indoor_ccm.pth', map_location='cpu'))
# Inference mode: disables dropout and uses batch-norm running statistics.
model.eval()
# Class names, indexed by the model's output position.
# torchvision's ImageFolder assigns class indices in sorted order of the
# class directory names, so this list must be in that same (alphabetical)
# order; the previous order swapped 'outdoor' and 'noncom_pliant'.
# NOTE(review): assumes the training folders are named exactly like these
# entries - confirm against image_datasets['train'].classes from training.
class_names = ['indoor', 'noncom_pliant', 'outdoor']
# Mapping from class name to the label shown to the user.
class_mapping = {
    'indoor': '室内',
    'outdoor': '室外',
    'noncom_pliant': '不合规'
}
# 推理函数
def predict_image(image_path_):
    """Classify a single image file and return its display label.

    Uses the module-level ``data_transforms``, ``model``, ``class_names`` and
    ``class_mapping``.

    Args:
        image_path_: Path of the image file to classify.

    Returns:
        The ``class_mapping`` value for the class with the highest score.
    """
    # Convert to RGB so grayscale, palette and RGBA files yield the three
    # channels that the 3-channel Normalize step and the network expect.
    # The context manager closes the file once the pixels are loaded.
    with Image.open(image_path_) as img:
        image = img.convert('RGB')
    batch = data_transforms(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        outputs = model(batch)
        _, pred = torch.max(outputs, 1)
    # .item() turns the one-element tensor into a plain int list index.
    return class_mapping[class_names[pred[0].item()]]
# Example inference
image_path = '65.jpg'  # replace with the path of the image to classify
predicted_class = predict_image(image_path)
print(f'图像分类为: {predicted_class}')
5、总结
| Model structure | Top-1 error | Top-5 error |
| --- | --- | --- |
| resnet18 | 30.24 | 10.92 |
| resnet34 | 26.70 | 8.58 |
| resnet50 | 23.85 | 7.13 |
| resnet101 | 22.63 | 6.44 |
| resnet152 | 21.69 | 5.94 |

其余 ResNet 的使用及操作后续逐步更新。