【扒代码】train.py

import torch
from torch import optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
# 假设其他必要的模块已经导入

# 检查是否有可用的GPU,如果有,则使用GPU,否则使用CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 构建模型并将其移动到选择的设备上
model = build_model(args).to(device)

# 初始化参数字典,用于区分不同的模型参数
backbone_params = dict()
non_backbone_params = dict()

# 遍历模型的参数并根据其名称分配到相应的字典中
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue  # 如果参数不需要梯度,则跳过
    if 'backbone' in n:
        backbone_params[n] = p  # 分配到backbone参数字典
    else:
        non_backbone_params[n] = p  # 分配到非backbone参数字典

# 创建AdamW优化器,它将使用不同的学习率来优化模型参数
optimizer = optim.AdamW(
    [
        {'params': non_backbone_params.values()},  # 非backbone参数
        {'params': backbone_params.values(), 'lr': args.backbone_lr}  # backbone参数,使用特定的学习率
    ],
    lr=args.lr,  # 基本学习率
    weight_decay=args.weight_decay,  # 权重衰减
)

# 创建学习率调度器,用于在训练过程中调整学习率
scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_drop, gamma=0.25)

# 如果需要恢复训练,则加载检查点
if args.resume_training:
    checkpoint = torch.load(os.path.join(args.model_path, f'{args.model_name}.pt'))
    model.load_state_dict(checkpoint['model'])  # 加载模型状态
    start_epoch = checkpoint['epoch']  # 设置起始训练轮次
    best = checkpoint['best_val_ae']  # 设置最佳验证误差
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器状态
    scheduler.load_state_dict(checkpoint['scheduler'])  # 加载学习率调度器状态
else:
    start_epoch = 0  # 如果不恢复训练,则从第0轮开始
    best = 10000000000000  # 设置一个较大的初始最佳误差值

# 创建目标函数,这里使用的是ObjectNormalizedL2Loss
criterion = ObjectNormalizedL2Loss()

# 创建训练集和验证集的数据集对象
train = FSC147Dataset(
    args.data_path,
    args.image_size,
    split='train',
    num_objects=args.num_objects,
    tiling_p=args.tiling_p,
    zero_shot=args.zero_shot
)
val = FSC147Dataset(
    args.data_path,
    args.image_size,
    split='val',
    num_objects=args.num_objects,
    tiling_p=args.tiling_p
)

# 创建训练集和验证集的数据加载器,用于批量加载数据
train_loader = DataLoader(
    train,
    batch_size=args.batch_size,
    drop_last=True,  # 如果批次中的数据不足,则丢弃最后一批
    num_workers=args.num_workers  # 用于数据加载的工作进程数
)
val_loader = DataLoader(
    val,
    batch_size=args.batch_size,
    drop_last=False,  # 验证集保留最后一批数据,即使不足一个批次
    num_workers=args.num_workers
)

功能解释

  • 代码首先检查是否有可用的GPU,并选择使用GPU或CPU。
  • build_model 函数用于根据提供的参数 args 构建模型,并将模型移动到选定的设备上。
  • 通过遍历模型的参数,将参数分配到 backbone_params 和 non_backbone_params 字典中,以便分别为它们设置不同的优化器参数。
  • 使用 optim.AdamW 创建优化器,它是一种使用Adam算法的优化器,带有一个权重衰减项。
  • 使用 StepLR 作为学习率调度器,它会在指定的轮次减少学习率。
  • 如果需要恢复训练,代码将加载先前保存的检查点,并恢复模型、优化器和学习率调度器的状态。
  • ObjectNormalizedL2Loss 用作训练过程中的损失函数。
  • FSC147Dataset 类用于创建训练集和验证集的数据集对象。
  • DataLoader 用于创建数据加载器,它将数据集封装成批次进行加载,以提高数据加载的效率。

整体而言,这段代码负责设置训练流程的初始化步骤,包括设备选择、模型构建、参数分配、优化器和学习率调度器的创建、训练和验证数据集的加载,以及训练前的检查点恢复。

检查点有什么用?

import torch
import torch.nn as nn
from tqdm import tqdm
from time import perf_counter

# 假设其他必要的模块已经导入

# 训练模型的轮次
for epoch in tqdm(range(start_epoch + 1, args.epochs + 1)):
    # 初始化训练和验证的损失和绝对误差
    train_loss = torch.tensor(0.0).to(device)
    val_loss = torch.tensor(0.0).to(device)
    aux_train_loss = torch.tensor(0.0).to(device)
    aux_val_loss = torch.tensor(0.0).to(device)
    train_ae = torch.tensor(0.0).to(device)
    val_ae = torch.tensor(0.0).to(device)

    # 记录训练开始的时间
    start = perf_counter()
    model.train()  # 设置模型为训练模式
    for img, bboxes, density_map in tqdm(train_loader):
        # 将数据移动到设备上(GPU或CPU)
        img = img.to(device)
        bboxes = bboxes.to(device)
        density_map = density_map.to(device)

        # 清除优化器的梯度
        optimizer.zero_grad()
        # 前向传播,获取模型的输出和辅助输出
        out, aux_out = model(img, bboxes)

        # 在不计算梯度的情况下,计算批次中对象的数量
        with torch.no_grad():
            num_objects = density_map.sum()

        # 计算主要损失和辅助损失
        main_loss = criterion(out, density_map, num_objects)
        aux_loss = sum([
            args.aux_weight * criterion(aux, density_map, num_objects) for aux in aux_out
        ])
        # 总损失是主要损失和辅助损失的和
        loss = main_loss + aux_loss
        # 反向传播
        loss.backward()
        # 如果设置了最大梯度范数,进行梯度裁剪
        if args.max_grad_norm > 0:
            nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        # 更新优化器参数
        optimizer.step()

        # 累加训练损失和绝对误差
        train_loss += main_loss * img.size(0)
        aux_train_loss += aux_loss * img.size(0)
        train_ae += torch.abs(
            density_map.flatten(1).sum(dim=1) - out.flatten(1).sum(dim=1)
        ).sum()

    # 设置模型为评估模式
    model.eval()
    with torch.no_grad():
        # 验证过程
        for img, bboxes, density_map in val_loader:
            img = img.to(device)
            bboxes = bboxes.to(device)
            density_map = density_map.to(device)
            out, aux_out = model(img, bboxes)

            num_objects = density_map.sum()

            main_loss = criterion(out, density_map, num_objects)
            aux_loss = sum([
                args.aux_weight * criterion(aux, density_map, num_objects) for aux in aux_out
            ])
            loss = main_loss + aux_loss

            # 累加验证损失和绝对误差
            val_loss += main_loss * img.size(0)
            aux_val_loss += aux_loss * img.size(0)
            val_ae += torch.abs(
                density_map.flatten(1).sum(dim=1) - out.flatten(1).sum(dim=1)
            ).sum()

功能解释

  • 代码使用 tqdm 库来显示进度条,提供训练和验证过程中的可视化反馈。
  • 在每个训练周期开始时,初始化损失和绝对误差的累积变量。
  • 使用 perf_counter 来记录训练周期的开始时间,可用于计算训练周期的持续时间。
  • 在训练循环中,将数据批次加载到指定的设备上(GPU或CPU),并执行模型的前向传播。
  • 使用 with torch.no_grad() 块来在不计算梯度的情况下计算批次中对象的数量,这有助于减少内存消耗。
  • 计算主要损失和辅助损失,并对总损失执行反向传播来更新模型的权重。
  • 如果设置了最大梯度范数,使用 clip_grad_norm_ 函数来防止梯度爆炸。
  • 在验证循环中,模型在评估模式下运行,不计算梯度,并计算验证损失和绝对误差。
  • 训练和验证的损失以及绝对误差在各自的循环中累积,并在循环结束后用于性能评估。

整体而言,这段代码提供了一个完整的训练和验证流程,包括了损失计算、反向传播、梯度裁剪、性能评估等关键步骤,是深度学习模型训练中的标准实践。

辅助损失是什么?

  • 9
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值