[训练和优化] 1. 训练流程

👋 你好!这里有实用干货与深度分享✨✨ 若有帮助,欢迎:​
👍 点赞 | ⭐ 收藏 | 💬 评论 | ➕ 关注 ,解锁更多精彩!​
📁 收藏专栏即可第一时间获取最新推送🔔。​
📖后续我将持续带来更多优质内容,期待与你一同探索知识,携手前行,共同进步🚀。​



人工智能

模型训练流程

本文详细介绍深度学习模型的训练流程,包括数据加载、训练循环、验证评估、训练监控等核心组件,帮助你高效、规范地完成模型训练。


1. 数据加载

1.1 数据集封装

自定义数据集类,支持不同数据格式和预处理需求。

from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)

        return sample, label

# 示例
train_dataset = CustomDataset(train_data, train_labels, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

1.2 数据预处理流水线

灵活的数据增强和标准化处理。

from torchvision import transforms

def create_transform_pipeline(is_training=True):
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]

    if is_training:
        transform_list.extend([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2)
        ])

    return transforms.Compose(transform_list)

# 示例
train_transform = create_transform_pipeline(is_training=True)
val_transform = create_transform_pipeline(is_training=False)

2. 训练循环

2.1 训练函数

标准训练过程,包含损失与准确率统计。

def train_epoch(model, train_loader, criterion, optimizer, device, log_interval=100):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        # 前向传播
        output = model(data)
        loss = criterion(output, target)

        # 反向传播
        loss.backward()

        # 梯度更新
        optimizer.step()

        # 统计
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)

        if batch_idx % log_interval == 0:
            print(f'Train Batch: {batch_idx}/{len(train_loader)} '
                  f'Loss: {loss.item():.6f} '
                  f'Acc: {100.*correct/total:.2f}%')

    return total_loss / len(train_loader), correct / total

2.2 验证函数

评估模型在验证集上的表现。

def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)

            # 前向传播
            output = model(data)
            val_loss += criterion(output, target).item()

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

    val_loss /= len(val_loader)
    accuracy = correct / total

    print(f'\nValidation set: Average loss: {val_loss:.4f}, '
          f'Accuracy: {100.*accuracy:.2f}%\n')

    return val_loss, accuracy

3. 完整训练流程

3.1 训练主函数

集成训练与验证,自动保存最佳模型。

def train(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, save_path):
    best_val_acc = 0
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # 训练阶段
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)

        # 验证阶段
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)

        # 保存最佳模型
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc
            }, save_path)

    return train_losses, val_losses

3.2 训练执行示例

完整训练流程示例,便于快速上手。

import torch
import torch.nn as nn
import torch.optim as optim

# 设置设备 GPU还是CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 初始化模型、损失函数、优化器
model = YourModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# 开始训练
train_losses, val_losses = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=100,
    save_path='best_model.pth'
)

4. 训练监控

4.1 训练过程可视化

直观展示训练与验证损失变化。

import matplotlib.pyplot as plt

def plot_training_history(train_losses, val_losses):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

4.2 训练日志记录

记录训练过程,便于追踪和复现。

import logging
import time

def setup_logger(log_file):
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)

logger = setup_logger('training.log')
logger.info('Starting training...')

5. 实践建议

  1. 数据加载优化

    • 使用 num_workers 进行多进程数据加载
    • 合理选择 batch_size
    • 使用 pin_memory=True 加速GPU训练
  2. 训练稳定性

    • 使用梯度裁剪防止梯度爆炸
    • 实现断点续训功能
    • 定期保存检查点
  3. 内存管理

    • 及时清理不需要的中间变量
    • 使用梯度累积处理大批量数据
    • 监控GPU内存使用情况




📌 感谢阅读!若文章对你有用,别吝啬互动~​
👍 点个赞 | ⭐ 收藏备用 | 💬 留下你的想法 ,关注我,更多干货持续更新!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

斌zz

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值