PyTorch迁移学习实战:基于ResNet的CIFAR-10图像分类

1. 环境配置与数据准备  

1.1开发环境

| 库名称       | 版本号                  | 编译版本                | 备注                     |
|-------------|-------------------------|-------------------------|--------------------------|
| **PyTorch** | 2.5.0                  | cu118                   | CUDA 11.8加速支持        |
| **Torchvision** | 0.20.0         | cu118                   | 配套图像处理工具集       |
| **CUDA**    | 11.8                    | -                       | 需与PyTorch版本严格匹配  |

1.2CIFAR-10与ImageNet的核心差异

CIFAR-10数据集由60,000张 32×32像素 的RGB三通道小尺寸图像组成,共10个类别,其分辨率仅为ImageNet标准输入尺寸(224×224)的 1/50,且图像内容通常为完整目标的中心裁剪,而ImageNet多为高分辨率场景化图片。这一特性导致直接应用ImageNet预训练模型时需改造输入层(如替换首层卷积核与池化层),并需针对性设计数据增强策略(如缩小随机裁剪比例、降低模糊强度),以避免小尺寸图像因下采样过度丢失细节特征。

2.代码部分

2.1导入PyTorch相关库及工具,处理图像并训练模型

import torch
from torch import nn
import torchvision
from torchvision import models
from torchvision import datasets,transforms
from datetime import datetime
import sys
from tqdm import tqdm
import numpy as np

2.2配置模型训练参数

# --------------------------
# 数据集配置
# --------------------------
# input_shape = 32  # 图像输入尺寸(假设为32x32像素,适用于CIFAR-10等数据集)
num_classes = 10     # 分类任务的目标类别数(例如:CIFAR-10的10个类别)

# --------------------------
# 超参数配置
# --------------------------
batch_size = 32       # 批量大小(根据GPU显存调整,典型值:32/64/128)
num_epochs = 5        # 训练轮次(整个数据集遍历次数)
learning_rate = 1e-3  # 初始学习率(常用范围:1e-5 到 1e-2)

# --------------------------
# 硬件加速配置
# --------------------------
# 自动检测并设置计算设备(优先使用GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

1.3预处理并加载CIFAR-10训练和测试数据

# 定义数据预处理流程组合
transform = transforms.Compose([
    # 调整图像尺寸到224x224像素(原始CIFAR-10图像为32x32,此处通过插值放大)
    transforms.Resize(size=(224,224)),
    
    # 将PIL图像或numpy数组转换为PyTorch张量
    # 执行以下操作:
    # 1. 维度变换:(高度, 宽度, 通道) → (通道, 高度, 宽度)
    # 2. 值域转换:将像素值从0-255整数 → 0.0-1.0浮点数
    transforms.ToTensor(),
    
    # 标准化处理(使用CIFAR-10数据集统计值)
    # 计算方式:标准化后像素值 = (原始像素值 - 均值) / 标准差
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),  # RGB三通道均值(来自CIFAR-10训练集统计)
        std=(0.2023, 0.1994, 0.2010)    # RGB三通道标准差
    )
])

# 创建CIFAR-10训练数据集实例
train_dataset = datasets.CIFAR10(
    root='CIFAR10/',   # 数据集存储根目录(自动创建该文件夹)
    train=True,        # 加载训练集(共50,000张图像)
    download=True,     # 如果本地不存在数据集则自动下载
    transform=transform # 对训练集应用上述预处理流程
)

# 创建CIFAR-10测试数据集实例
test_dataset = datasets.CIFAR10(
    root='CIFAR10/',   # 数据集存储路径(与训练集相同)
    train=False,       # 加载测试集(共10,000张图像)
    download=True,     # 同上
    transform=transform # 对测试集应用相同预处理
)

train_dataloader=torch.utils.data.DataLoader(dataset=train_dataset,
                                             shuffle=True,
                                             batch_size=batch_size)
test_dataloader=torch.utils.data.DataLoader(dataset=test_dataset,
                                             shuffle=False,
                                             batch_size=batch_size)

2.3加载预训练ResNet-18模型,修改全连接层适配10分类任务,并转移至计算设备

model = models.resnet18(pretrained=True)
in_features=model.fc.in_features
model.fc=nn.Linear(in_features,10)
model = model.to(device)

2.4定义交叉熵损失函数,使用SGD优化器,设置学习率、动量和权重衰减

# 定义损失函数:交叉熵损失(CrossEntropyLoss)
# 适用于多分类任务,内部自动结合Softmax和负对数似然损失
# 输入:模型输出(未经过Softmax的logits),目标标签
# 输出:标量损失值
criterion = nn.CrossEntropyLoss()

# 定义优化器:随机梯度下降(SGD)
# 使用SGD优化算法更新模型参数
optimizer = torch.optim.SGD(
    model.parameters(),  # 需要优化的模型参数(通过model.parameters()获取)
    lr=learning_rate,    # 学习率(控制参数更新步长)
    momentum=0.9,        # 动量因子(加速收敛并减少震荡)
    weight_decay=5e-4    # 权重衰减(L2正则化系数,防止过拟合)
)

2.5训练模型:遍历数据,计算损失和准确率,反向传播更新参数,定期打印进度

# 外层循环:遍历每个训练轮次(epoch)
for epoch in range(num_epochs):
    # 内层循环:遍历训练数据加载器中的每个批次(batch)
    # enumerate(train_dataloader) 返回 (batch索引, (图像数据, 标签))
    for batch_idx, (images, labels) in enumerate(train_dataloader):
        
        # 将当前批次的图像和标签转移到指定设备(GPU/CPU)
        images = images.to(device)  # 图像张量,形状为 [batch_size, 通道数, 高度, 宽度]
        labels = labels.to(device)  # 标签张量,形状为 [batch_size]
        
        # 前向传播:将图像输入模型得到预测结果
        out = model(images)  # 输出形状为 [batch_size, num_classes]
        
        # 计算损失:比较模型输出与真实标签的差异
        loss = criterion(out, labels)  # 返回标量损失值
        
        # 计算当前批次的准确率
        # out.argmax(axis=1): 获取每个样本预测概率最高的类别索引
        # (out.argmax(axis=1) == labels): 比较预测类别与真实类别,返回布尔张量
        # .sum().item(): 统计正确预测的数量并转换为Python标量
        n_corrects = (out.argmax(axis=1) == labels).sum().item()
        
        # 计算当前批次的准确率
        acc = n_corrects / labels.size(0)  # labels.size(0) 是当前批次的样本数
        
        # 梯度清零:清空优化器中缓存的梯度(防止梯度累加)
        optimizer.zero_grad()
        
        # 反向传播:计算损失关于模型参数的梯度
        loss.backward()
        
        # 参数更新:根据梯度下降算法更新模型参数
        optimizer.step()
        
        # 每100个批次打印一次训练状态
        if (batch_idx + 1) % 100 == 0:
            # 打印格式:
            # 当前时间, 当前epoch/总epoch数, 当前batch/总batch数, 当前损失值, 当前准确率
            print(f'{datetime.now()}, {epoch + 1}/{num_epochs}, {batch_idx + 1}/{total_batch}: {loss.item():.4f}, acc: {acc}')

2.6测试模型:遍历测试数据,统计预测正确率并输出

# 初始化累计统计变量
total = 0    # 累计测试样本总数
correct = 0  # 累计正确预测数量

# 遍历测试数据加载器中的每个批次
# 使用 tqdm 包装器显示进度条
for images, labels in tqdm(test_dataloader):
    # 获取当前批次的图像和标签
    images = images.to(device)  # 当前批次图像,形状为 [batch_size, 通道数, 高度, 宽度]
    labels = labels.to(device)  # 当前批次标签,形状为 [batch_size]
    
    # 前向传播:将图像输入模型得到预测结果
    out = model(images)  # 输出形状为 [batch_size, num_classes]
    
    # 获取每个样本预测概率最高的类别索引
    preds = torch.argmax(out, dim=1)  # 形状为 [batch_size]
    
    # 累加当前批次的样本数到总数
    total += images.size(0)  # images.size(0) 是当前批次的样本数
    
    # 累加当前批次的正确预测数量
    correct += (preds == labels).sum().item()  # 统计正确预测数并转换为Python标量

# 打印测试结果
# 格式:正确预测数/总样本数=准确率
print(f'{correct}/{total}={correct/total}')

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值