PyTorch Image Classification: Recognizing Foods in a Local Dataset

1. Data Preparation and Preprocessing

1.1 Generating path files for the local dataset

import os

def train_test_file(root, dir):
    # Write one "image_path label_index" line per image
    with open(dir + '.txt', 'w') as file_txt:
        path = os.path.join(root, dir)
        for roots, directories, files in os.walk(path):
            if directories:  # top level: record the class directories
                dirs = directories
            else:            # leaf level: files belong to the current class
                now_dir = os.path.basename(roots)  # extract the current class name
                for file in files:
                    path_1 = os.path.join(roots, file)
                    file_txt.write(f"{path_1} {dirs.index(now_dir)}\n")

Key points

  • os.walk recursively traverses the directory tree
  • os.path.join (and os.path.basename) keep path handling cross-platform
  • Each image path is mapped to its label and stored in a text file
  • Classes are numbered automatically by their directory order (see the usage sketch below)
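A minimal usage sketch; the root directory name food_dataset and the train/test subfolder layout are assumptions about how the images are organized on disk:

train_test_file('./food_dataset', 'train')   # writes train.txt with "path label" lines
train_test_file('./food_dataset', 'test')    # writes test.txt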

1.2 Transform configuration

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([256,256]),  # resize to a uniform size
        transforms.ToTensor(),         # convert to a tensor
    ]),
    'valid': transforms.Compose([
        transforms.Resize([256,256]),
        transforms.ToTensor(),
    ])
}

Suggested extensions

  • Add data augmentation to the training pipeline:
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomRotation(15)

  • Add normalization to the validation pipeline (a combined sketch follows below):
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
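Putting both suggestions together, a sketch of an augmented pipeline (the ImageNet mean/std statistics are the conventional assumption here):

from torchvision import transforms

augmented_transforms = {
    'train': transforms.Compose([
        transforms.Resize([256,256]),
        transforms.RandomHorizontalFlip(),                     # random left-right flip
        transforms.ColorJitter(brightness=0.2, contrast=0.2),  # mild color perturbation
        transforms.RandomRotation(15),                         # rotate within ±15 degrees
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],       # assumed ImageNet statistics
                             std=[0.229, 0.224, 0.225]),
    ]),
    'valid': transforms.Compose([
        transforms.Resize([256,256]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
}

Note that augmentation belongs only in the training pipeline; the validation pipeline stays deterministic so results are reproducible.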
    

2. Custom Dataset Class

2.1 Implementing the Dataset class

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.transform = transform
        self.imgs, self.labels = [], []
        with open(file_path) as f:
            for line in f:
                img_path, label = line.strip().split()
                self.imgs.append(img_path)
                self.labels.append(int(label))  # convert label to int

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        image = Image.open(self.imgs[index]).convert('RGB')  # force 3 channels
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(self.labels[index])

Key points

  • A subclass of torch.utils.data.Dataset must implement three methods: __init__, __len__, and __getitem__

  • PIL.Image handles a wide range of image formats

  • Lazy loading: images are opened on access rather than all read into memory up front (a quick sanity check follows below)
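A quick sanity check of the dataset class, assuming the train.txt produced in section 1.1 exists:

ds = food_dataset('train.txt', transform=data_transforms['train'])
img, lbl = ds[0]                # triggers lazy loading of the first image
print(len(ds), img.shape, lbl)  # e.g. N torch.Size([3, 256, 256]) tensor(0)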


3. Data Loading and Batching

3.1 DataLoader configuration

train_loader = DataLoader(
    dataset=training_data,
    batch_size=64,        # batch size
    shuffle=True,         # shuffle the training set each epoch
    num_workers=4,        # worker processes for parallel loading
    pin_memory=True       # page-locked memory for faster GPU transfer
)

test_loader = DataLoader(
    test_data,
    batch_size=64,
    shuffle=False         # no shuffling needed for the test set
)

Parameter reference

Parameter   | Role                                           | Recommended value
batch_size  | balances memory use and gradient stability     | 32-256
shuffle     | breaks ordering correlation in the data        | training set only
num_workers | number of parallel loading processes           | CPU cores - 2
pin_memory  | page-locked memory speeds up copies to the GPU | enable for GPU training
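A sketch of deriving num_workers from the CPU count, following the table above (training_data comes from section 2):

import os

num_workers = max((os.cpu_count() or 2) - 2, 0)  # leave two cores for the main process
train_loader = DataLoader(training_data, batch_size=64, shuffle=True,
                          num_workers=num_workers, pin_memory=True)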

4. Convolutional Neural Network Architecture

4.1 Network definition

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),  # 3 → 16 channels
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                 # halves the spatial size

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*64*64, 256),  # fully connected layer
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 20)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

Structure analysis


Dimension trace

Input: (3, 256, 256)
Conv1 → (16, 256, 256)
Pool1 → (16, 128, 128)
Conv2 → (32, 128, 128)
Pool2 → (32, 64, 64)
Conv3 → (64, 64, 64)
Flatten → 64*64*64 = 262,144
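The trace can be verified with a dummy forward pass through the CNN class defined in 4.1:

model = CNN()
dummy = torch.randn(1, 3, 256, 256)  # one fake RGB image
print(model.features(dummy).shape)   # torch.Size([1, 64, 64, 64])
print(model(dummy).shape)            # torch.Size([1, 20])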

5. The Training Loop in Detail

5.1 Training function

def train(dataloader, model, loss_fn, optimizer):
    model.train()  # training mode
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # forward pass
        pred = model(X)
        loss = loss_fn(pred, y)

        # backward pass
        optimizer.zero_grad()  # clear accumulated gradients
        loss.backward()        # autograd computes gradients
        optimizer.step()       # update parameters

        # progress logging
        if batch % 10 == 0:
            print(f"Batch {batch}: loss={loss.item():.4f}")

Key points

  • model.train() enables BatchNorm updates and Dropout
  • the standard idiom for moving tensors to the target device
  • gradients accumulate by default, so they must be zeroed each step
  • three common loss choices (see the sketch below):
    • single-label classification: CrossEntropyLoss
    • multi-label classification: BCEWithLogitsLoss
    • regression: MSELoss
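A minimal sketch of the single-label case (torch and nn imported as in the complete code below): CrossEntropyLoss expects raw logits and integer class indices, not probabilities or one-hot vectors.

loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(4, 20)            # batch of 4, 20 classes, no softmax applied
targets = torch.tensor([3, 7, 0, 19])  # class indices
print(loss_fn(logits, targets))        # scalar loss tensor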

5.2 Optimizer configuration

optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,            # initial learning rate
    betas=(0.9, 0.999), # momentum coefficients
    weight_decay=1e-4   # L2 regularization
)

scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,        # every 5 epochs...
    gamma=0.1           # ...multiply the learning rate by 0.1
)

Optimization strategy

  • Adam adapts the learning rate per parameter
  • the scheduler adjusts the learning rate during training (placement shown below)
  • weight decay helps prevent overfitting
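A sketch of where scheduler.step() belongs, assuming the training objects defined elsewhere in this post: once per epoch, after the per-batch optimizer updates inside train().

for epoch in range(epochs):
    train(train_loader, model, loss_fn, optimizer)  # optimizer.step() runs per batch inside
    scheduler.step()                                # decays the lr every step_size epochs
    print(f"epoch {epoch+1}: lr={scheduler.get_last_lr()[0]:.6f}")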

6. Validation and Testing

6.1 Test function

def test(dataloader, model, loss_fn):
    model.eval()  # evaluation mode
    total_loss = 0
    correct = 0

    with torch.no_grad():  # disable gradient tracking
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)

            # accumulate loss
            total_loss += loss_fn(pred, y).item()

            # count correct predictions
            correct += (pred.argmax(1) == y).sum().item()

            # per-sample diagnostics
            for pred_idx, true_idx in zip(pred.argmax(1), y):
                print(f"predicted: {labels[pred_idx.item()]} | actual: {labels[true_idx.item()]}")

    avg_loss = total_loss / len(dataloader)
    accuracy = 100 * correct / len(dataloader.dataset)
    print(f"Test result: accuracy={accuracy:.2f}%, avg loss={avg_loss:.4f}")
    return accuracy  # used by the training loop in section 7 to track the best model

Key points

  • model.eval() disables Dropout and freezes BatchNorm statistics
  • torch.no_grad() cuts memory use during inference
  • two metrics are tracked: average loss and accuracy
  • per-sample predictions are printed for quick diagnosis (a per-class breakdown is sketched below)
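Beyond overall accuracy, a per-class breakdown often reveals which foods the model confuses; a sketch assuming the model, device, test_loader, and labels dict defined in the complete code:

from collections import defaultdict

hits, totals = defaultdict(int), defaultdict(int)
with torch.no_grad():
    for X, y in test_loader:
        X, y = X.to(device), y.to(device)
        preds = model(X).argmax(1)
        for p, t in zip(preds.tolist(), y.tolist()):
            totals[t] += 1
            hits[t] += int(p == t)
for idx in sorted(totals):
    print(f"{labels[idx]}: {100 * hits[idx] / totals[idx]:.1f}%")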

7. The Full Training Loop

7.1 Main training flow

epochs = 20
best_acc = 0

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")

    # training phase
    train(train_loader, model, loss_fn, optimizer)

    # validation phase
    acc = test(test_loader, model, loss_fn)

    # keep the best checkpoint
    if acc > best_acc:
        torch.save(model.state_dict(), "best_model.pth")
        best_acc = acc

print("Training finished! Best accuracy:", best_acc)

8. Complete Code

import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
import numpy as np
from PIL import Image
from torchvision import transforms

# class index → food name
labels = {0:'八宝粥',
          1:'哈密瓜',
          2:'圣女果',
          3:'巴旦木',
          4:'板栗',
          5:'汉堡',
          6:'火龙果',
          7:'炸鸡',
          8:'瓜子',
          9:'生肉',
          10:'白萝卜',
          11:'胡萝卜',
          12:'草莓',
          13:'菠萝',
          14:'薯条',
          15:'蛋',
          16:'蛋挞',
          17:'青菜',
          18:'骨肉相连',
          19:'鸡翅'}

data_transforms = {
    'train':
        transforms.Compose([
            transforms.Resize([256,256]),
            transforms.ToTensor(),
        ]),
    'valid':
        transforms.Compose([
            transforms.Resize([256,256]),
            transforms.ToTensor(),
        ]),
}
class food_dataset(Dataset):
    def __init__(self,file_path,transform=None):
        self.file_path = file_path
        self.imgs = []
        self.labels =[]
        self.transform = transform
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)
    def __len__(self):
        return len(self.imgs)
    def __getitem__(self, index):
        image = Image.open(self.imgs[index]).convert('RGB')  # force 3 channels (guards against grayscale/RGBA files)
        if self.transform:
            image = self.transform(image)

        label = self.labels[index]
        label = torch.from_numpy(np.array(label, dtype=np.int64))
        return image, label

training_data = food_dataset(file_path = './train.txt',transform = data_transforms['train'])
test_data =food_dataset(file_path = './test.txt',transform = data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)  # evaluation needs no shuffling

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,16,3,1,1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(16,32,3,1,1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(32,64,3,1,1),
            nn.ReLU(),
        )
        self.out = nn.Linear(64*64*64,20)
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0),-1)
        x = self.out(x)
        return x

model = CNN().to(device)

def train(dataloader, model, loss_fn, optimizer):
    model.train()

    batch_size_num = 1

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_size_num % 10 == 0:  # log every 10 batches, as in section 5.1
            print(f"loss: {loss.item():>7f}  [batch: {batch_size_num}]")
        batch_size_num += 1


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            for pred_idx, true_idx in zip(pred.argmax(1).tolist(), y.tolist()):
                print(f"predicted: {labels[pred_idx]} \t actual: {labels[true_idx]}")

    test_loss /= num_batches
    correct /= size
    print(f"Test result: \n Accuracy: {(100*correct):.2f}%, Avg loss: {test_loss:.4f} ")

loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # learning rate as configured in section 5.2

epochs = 10
for e in range(epochs):
    print(f"Epoch {e+1}\n-------------------------------")
    train(train_dataloader,model,loss_fn,optimizer)
    print()
print("Done!")
test(test_dataloader,model,loss_fn)

9. Troubleshooting

9.1 Data issues

  • Symptom: loss does not decrease
  • Check:
    1. Are the data paths correct?
    2. Do the labels match their images?
    3. Is the input data normalized?

9.2 Model issues

  • Symptom: accuracy is no better than random guessing
  • Check:
    1. Is a redundant Softmax applied before CrossEntropyLoss? (it already applies log-softmax internally)
    2. Is the learning rate too large?
    3. Are gradients exploding? (add gradient clipping, sketched below)
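A sketch of gradient clipping inside the training step, between backward() and step():

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # cap the global gradient norm
optimizer.step()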

9.3 Training issues

  • Symptom: low GPU utilization
  • Optimize:
    1. Increase batch_size
    2. Enable pin_memory
    3. Increase prefetch_factor (see the DataLoader sketch below)
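A sketch applying all three suggestions at once (training_data from section 8); note that prefetch_factor only takes effect when num_workers > 0:

fast_loader = DataLoader(
    training_data,
    batch_size=128,      # larger batches keep the GPU busier
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    prefetch_factor=4,   # batches preloaded per worker (default is 2)
)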
