图像分类-模型的训练及推理

1、摘要

1.1、声明

*本博客重实际操作轻理论,按照流程既可完成自己需要实现的图像多分类问题,不再追叙官方模型及相关文献。

官方demo:
https://huggingface.co/spaces/pytorch/ResNet

1.2、ResNet18/34

ResNet34,即深度残差网络(Residual Network)的一种变体,特指包含34层卷积层的深度神经网络结构。ResNet34是一种具有深度残差结构的卷积神经网络,通过引入跳跃连接和瓶颈设计等创新点,解决了深度神经网络训练中的难题,并在图像分类、目标检测、图像分割等多个领域得到了广泛应用。

1.3、torch/torchvision

框架和模型构建相关。
*理论和原理建议参照如下:

resnet34文档:
http://pytorch.org/vision/main/models/generated/torchvision.models.resnet34.html

torchvision文档:
https://pytorch.org/vision/stable/index.html

2、数据准备
 

所有数据依据自己需要的类别进行分类,并且自定义分类名称类别。

3、训练

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2024/7/8
# @Author : CCM
# @Describe : 图像多分类训练代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torchvision.models import ResNet34_Weights
from PIL import Image
import os


# 检查数据集中的文件
def check_dataset(data_dir_):
    for root, dirs, files in os.walk(data_dir_):
        for file in files:
            if file.endswith(('.jpg', '.png', '.jpeg')):
                try:
                    img = Image.open(os.path.join(root, file))
                    img.verify()  # 验证图像是否完整
                except (IOError, SyntaxError) as e:
                    print(f"文件损坏: {os.path.join(root, file)}")


data_dir = 'data'
check_dataset(data_dir)

# 数据预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=0)
               for x in ['train', 'val']}  # 将 num_workers 改为 0
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 打印训练和验证的数据量
print(f"训练数据量: {dataset_sizes['train']}")
print(f"验证数据量: {dataset_sizes['val']}")

# 定义模型
model_ft = models.resnet34(weights=ResNet34_Weights.DEFAULT)
num_frs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_frs, 3)  # 修改为3个类别

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# 优化器
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# 学习率调整策略
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)


# 训练模型
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 40)
        """
        篇幅原因需要私信获取
        """
    return model


model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
# 保存模型
torch.save(model_ft.state_dict(), 'resnet34_indoor_ccm.pth')

4、推理

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2024/7/8
# @Author : CCM
# @Describe : 使用训练模型进行推理,此为三分类


import torch
import torch.nn as nn
from torchvision import models, transforms
from torchvision.models import ResNet34_Weights
from PIL import Image


# 数据预处理
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 加载模型
model = models.resnet34(weights=ResNet34_Weights.DEFAULT)  # 使用最新的权重枚举
num_frs = model.fc.in_features
model.fc = nn.Linear(num_frs, 3)  # 修改为3个类别

# 加载训练好的模型权重
model.load_state_dict(torch.load('resnet34_indoor_ccm.pth'))
model.eval()

# 定义类别
class_names = ['indoor', 'outdoor', 'noncom_pliant']
# 定义类别映射字典
class_mapping = {
    'indoor': '室内',
    'outdoor': '室外',
    'noncom_pliant': '不合规'
}


# 推理函数
def predict_image(image_path_):
    image = Image.open(image_path_)
    image = data_transforms(image)
    image = image.unsqueeze(0)  # 增加一个批次维度
    with torch.no_grad():
        outputs = model(image)
        _, pred = torch.max(outputs, 1)
        return class_mapping[class_names[pred[0]]]


# 示例推理
image_path = '65.jpg'  # 替换为你要推理的图像路径
predicted_class = predict_image(image_path)
print(f'图像分类为: {predicted_class}')

5、总结

Model structureTop-1 errorTop-5 error
resnet1830.2410.92
resnet3426.708.58
resnet5023.857.13
resnet10122.636.44
resnet15221.695.94

其余ResNet的使用及操作后续逐步更新。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值