nnUnet(代码)-训练部分

这篇博客深入解析了nnUNet训练过程,包括训练计划的获取、数据增强参数初始化、五折交叉验证、数据加载、网络与优化器初始化。nnUNetTrainer版本一采用DC_and_CE_loss作为损失函数,使用Adam优化器和基于损失平均值的学习率调度。版本二引入了深监督,通过权重调整强化损失函数,并改用SGD优化器及自定义学习率下降策略。此外,还讨论了数据增强参数的变化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

学习目标:逐步分析nnunet训练部分

学习内容:training部分

· 拿到训练plans(计划)
· 初始化数据增强参数
· 采用五折交叉验证
· dataset与dataloader/数据加载过程
· 初始化网络
· 初始化优化器与学习率函数

1.nnUNetTrainer(版本一的训练方法)

··· 损失函数:

self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

··· 优化器与学习率函数:
优化器用adam
学习率的调整是用的损失函数的加权平均值来判断是否变动的方法

    def initialize_optimizer_and_scheduler(self):
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                          amsgrad=True)
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
                                                           patience=self.lr_scheduler_patience,
                                                           verbose=True, threshold=1e-3,
                                                           threshold_mode="abs")
# 学习率函数设置
self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new

    def update_train_loss_MA(self):
        if self.train_loss_MA is None:
            self.train_loss_MA = self.all_tr_losses[-1]
        else:
            self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
                                 self.all_tr_losses[-1]
# lr scheduler is updated with moving average val loss. should be more robust
self.lr_scheduler.step(self.train_loss_MA)

2.nnUNetTrainerV2(版本二的训练方法)

··· 加强了损失函数(深监督):
还是原来损失,但是添加了一个策略:给每层的损失加一个权重,分辨率越高的权重越大,简单说就是针对中间隐藏层特征透明度不高以及深层网络中浅层以及中间网络难以训练的问题。

################# 封装损失函数进入深度学习(深监督) ############
        # 需要知道网络深度
        # net_numpool = len(self.plans['pool_op_kernel_sizes'])

        # 我们给每个输出一个权重,该权重随着分辨率的降低呈指数递减(除以2)
        # 这使得更高的分辨率输出在损失中有更大的权重
        weights = np.array([1 / (2 ** i) for i in range(self.net_numpool)])

        # 我们不使用最低的2个输出。标准化权重,使其总和为1
        mask = np.array([True] + [True if i < self.net_numpool - 1 else False for i in range(1, self.net_numpool)])
        weights[~mask] = 0
        weights = weights / weights.sum()
        self.ds_loss_weights = weights

        # 封装损失函数
        self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)

··· 重写了优化器与学习率函数
采用SGD与自定义的学习率下降函数

    def initialize_optimizer_and_scheduler(self):
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                         momentum=0.99, nesterov=True)
        self.lr_scheduler = None
 def maybe_update_lr(self, epoch=None):

        if epoch is None:
            ep = self.epoch + 1
        else:
            ep = epoch
        self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
        
def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs)**exponent

··· 重写了数据增强参数

3.后面还有DP等等三四个版本,是基于版本二改变的,主要是通过混合精度进行训练增加训练速度

### 关于nnUNetV2训练教程及相关常见问题 #### nnUNetV2的训练流程概述 nnUNet 是一种自动化的医学图像分割框架,其 V2 版本引入了许多改进功能。为了成功运行 nnUNetV2 的训练过程,通常需要完成以下几个核心环节:环境配置、数据准备以及模型训练与评估。 在环境配置方面,官方文档提供了详细的指导说明[^3],建议按照指南安装必要的依赖库并设置路径变量以便顺利执行脚本命令。对于初学者而言,可以参考一篇详尽的博文来帮助理解整个工作流[^2]。 当涉及到具体操作时,比如定义不同的 fold 来实现交叉验证机制,在目录结构 `nnUNetTrainer__nnUNetPlans__2d` 下会自动生成各 folds 对应的日志文件夹用于存储每次迭代的结果信息[^1]。这些日志不仅有助于监控当前进度状态而且便于后续分析性能瓶颈所在位置。 另外值得注意的是如果想要定制化设计新的神经网络架构,则需深入研究源码内部逻辑并通过继承原有类或者重写部分函数的方式达成目标。 ```bash # 启动默认参数下的二维训练任务示例代码片段如下所示: nnUNet_train 2d Task001_BrainTumour nnUNetPlans --preprocess_only ``` 上述命令展示了如何利用预处理好的数据集启动一项基于平面视图的任务实例,并指定了计划名称作为输入之一。 #### 常见错误排查技巧 以下是几个可能遇到的技术难题及其解决方案: - **内存溢出**:调整批量大小(batch size),减少GPU显存占用率;优化图片尺寸裁剪策略降低计算复杂度。 - **收敛缓慢或不稳定**:检查学习率设定是否合理适当;尝试更换损失函数类型观察效果差异变化情况。 - **预测精度低下**:重新审视标注质量是否存在偏差现象;增加样本数量扩充多样性覆盖范围更广场景特征表达能力更强模型泛化水平更高。 通过以上方法能够有效提升开发效率缩短调试周期最终获得满意成果!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值