Model Training Workflow
This article walks through the training workflow for a deep learning model, covering the core components: data loading, the training loop, validation, and training monitoring, so you can train models efficiently and reproducibly.
1. Data Loading
1.1 Dataset Wrapper
A custom Dataset class supports different data formats and preprocessing needs.
```python
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label

# Example
train_dataset = CustomDataset(train_data, train_labels, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          num_workers=4, pin_memory=True)
```
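As a quick sanity check, the wrapper above can be exercised with in-memory tensors; the `data` and `labels` tensors below are illustrative stand-ins for a real dataset:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    # Same wrapper as above, repeated so this snippet runs standalone
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label

# Ten 8-dimensional feature vectors with binary labels
data = torch.randn(10, 8)
labels = torch.randint(0, 2, (10,))
dataset = CustomDataset(data, labels)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch_x, batch_y = next(iter(loader))
print(batch_x.shape)  # torch.Size([4, 8])
```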
1.2 Data Preprocessing Pipeline
A flexible pipeline for data augmentation and normalization. Note that augmentation should run on the raw (PIL) images, before tensor conversion and normalization.
```python
from torchvision import transforms

def create_transform_pipeline(is_training=True):
    # Augmentation first, on PIL images; then tensor conversion and normalization
    transform_list = []
    if is_training:
        transform_list.extend([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2)
        ])
    transform_list.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return transforms.Compose(transform_list)

# Example
train_transform = create_transform_pipeline(is_training=True)
val_transform = create_transform_pipeline(is_training=False)
```
2. Training Loop
2.1 Training Function
A standard training pass that tracks loss and accuracy.
```python
def train_epoch(model, train_loader, criterion, optimizer, device, log_interval=100):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        # Backward pass
        loss.backward()
        # Parameter update
        optimizer.step()
        # Statistics
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)
        if batch_idx % log_interval == 0:
            print(f'Train Batch: {batch_idx}/{len(train_loader)} '
                  f'Loss: {loss.item():.6f} '
                  f'Acc: {100.*correct/total:.2f}%')
    return total_loss / len(train_loader), correct / total
```
2.2 Validation Function
Evaluates the model on the validation set.
```python
import torch

def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            # Forward pass only; no gradients are tracked
            output = model(data)
            val_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    val_loss /= len(val_loader)
    accuracy = correct / total
    print(f'\nValidation set: Average loss: {val_loss:.4f}, '
          f'Accuracy: {100.*accuracy:.2f}%\n')
    return val_loss, accuracy
```
3. Complete Training Workflow
3.1 Main Training Function
Combines training and validation, and automatically saves the best model.
```python
def train(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, save_path):
    best_val_acc = 0
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        # Training phase
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        # Validation phase
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        # Save the best model so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc
            }, save_path)
    return train_losses, val_losses
```
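The checkpoint dictionary saved above can later be restored to resume an interrupted run. A minimal sketch (the `load_checkpoint` helper name is my own, not part of the original code):

```python
import torch

def load_checkpoint(model, optimizer, ckpt_path, device):
    """Restore model/optimizer state saved by train() and return where to resume."""
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # continue after the saved epoch
    best_val_acc = checkpoint['best_val_acc']
    return start_epoch, best_val_acc
```

After calling it, loop over `range(start_epoch, num_epochs)` instead of `range(num_epochs)` so the epoch counter and best-accuracy tracking carry over.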
3.2 Training Execution Example
A complete end-to-end example to get started quickly.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Select the device: GPU if available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model, loss function, and optimizer
model = YourModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Create the data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Start training
train_losses, val_losses = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=100,
    save_path='best_model.pth'
)
```
4. Training Monitoring
4.1 Visualizing the Training Process
Plots the training and validation loss curves for an at-a-glance view of progress.
```python
import matplotlib.pyplot as plt

def plot_training_history(train_losses, val_losses):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()
```
4.2 Training Log Recording
Logs the training run for tracking and reproducibility.
```python
import logging

def setup_logger(log_file):
    # Log to both a file and the console
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)

logger = setup_logger('training.log')
logger.info('Starting training...')
```
5. Practical Tips

- Data loading optimization
  - Use num_workers for multi-process data loading
  - Choose a reasonable batch_size
  - Use pin_memory=True to speed up host-to-GPU transfers
- Training stability
  - Use gradient clipping to prevent exploding gradients
  - Support resuming training from checkpoints
  - Save checkpoints periodically
- Memory management
  - Free intermediate variables that are no longer needed
  - Use gradient accumulation to simulate large batches
  - Monitor GPU memory usage
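The gradient-clipping and gradient-accumulation tips above can be combined into a single training step. A sketch under the assumption of a classification setup like the one in Section 2; the function name `train_step_accumulated` and its parameters are illustrative:

```python
import torch
import torch.nn as nn

def train_step_accumulated(model, batches, criterion, optimizer,
                           accum_steps, max_grad_norm=1.0):
    """One optimizer update built from `accum_steps` micro-batches,
    with gradient clipping applied before each step."""
    model.train()
    optimizer.zero_grad()
    total_loss = 0.0
    for i, (data, target) in enumerate(batches):
        output = model(data)
        # Scale the loss so the accumulated gradient matches a full-batch average
        loss = criterion(output, target) / accum_steps
        loss.backward()
        total_loss += loss.item()
        if (i + 1) % accum_steps == 0:
            # Clip the accumulated gradient norm to stabilize training
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
    return total_loss
```

With `accum_steps=4` and a per-micro-batch size of 8, each optimizer step sees the averaged gradient of an effective batch of 32, at the memory cost of a batch of 8.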