nnUnet(代码)-训练部分

学习目标:逐步分析nnunet训练部分

学习内容:training部分

· 拿到训练plans(计划)
· 初始化数据增强参数
· 采用五折交叉验证
· dataset与dataloader/数据加载过程
· 初始化网络
· 初始化优化器与学习率函数

1.nnUNetTrainer(版本一的训练方法)

··· 损失函数:

self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

··· 优化器与学习率函数:
优化器用adam
学习率的调整是用的损失函数的加权平均值来判断是否变动的方法

    def initialize_optimizer_and_scheduler(self):
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                          amsgrad=True)
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
                                                           patience=self.lr_scheduler_patience,
                                                           verbose=True, threshold=1e-3,
                                                           threshold_mode="abs")
# 学习率函数设置
self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new

    def update_train_loss_MA(self):
        if self.train_loss_MA is None:
            self.train_loss_MA = self.all_tr_losses[-1]
        else:
            self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
                                 self.all_tr_losses[-1]
# lr scheduler is updated with moving average val loss. should be more robust
self.lr_scheduler.step(self.train_loss_MA)

2.nnUNetTrainerV2(版本二的训练方法)

··· 加强了损失函数(深监督):
还是原来损失,但是添加了一个策略:给每层的损失加一个权重,分辨率越高的权重越大,简单说就是针对中间隐藏层特征透明度不高以及深层网络中浅层以及中间网络难以训练的问题。

################# 封装损失函数进入深度学习(深监督) ############
        # 需要知道网络深度
        # net_numpool = len(self.plans['pool_op_kernel_sizes'])

        # 我们给每个输出一个权重,该权重随着分辨率的降低呈指数递减(除以2)
        # 这使得更高的分辨率输出在损失中有更大的权重
        weights = np.array([1 / (2 ** i) for i in range(self.net_numpool)])

        # 我们不使用最低的2个输出。标准化权重,使其总和为1
        mask = np.array([True] + [True if i < self.net_numpool - 1 else False for i in range(1, self.net_numpool)])
        weights[~mask] = 0
        weights = weights / weights.sum()
        self.ds_loss_weights = weights

        # 封装损失函数
        self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)

··· 重写了优化器与学习率函数
采用SGD与自定义的学习率下降函数

    def initialize_optimizer_and_scheduler(self):
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                         momentum=0.99, nesterov=True)
        self.lr_scheduler = None
 def maybe_update_lr(self, epoch=None):

        if epoch is None:
            ep = self.epoch + 1
        else:
            ep = epoch
        self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
        
def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs)**exponent

··· 重写了数据增强参数

3.后面还有DP等等三四个版本,是基于版本二改变的,主要是通过混合精度进行训练增加训练速度

  • 1
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
NNUNet的多尺度训练过程包括以下几个步骤: 1. 数据预处理:将训练数据按照不同的尺度进行缩放和裁剪,以生成多个不同尺度的图像。在NNUNet中,通常将原始图像缩放到不同的尺度,例如1/2、1/4、1/8等。 2. 模型训练:使用多个尺度的图像进行模型训练。在NNUNet中,通常使用不同的尺度对模型进行多次训练,以使模型能够适应不同的输入图像。在每个尺度上,NNUNet都会使用跳跃连接将编码器和解码器之间的特征进行连接,以提高模型的性能和鲁棒性。 3. 模型融合:将多个训练过的模型的输出进行融合,以生成最终的分割结果。在NNUNet中,通常使用投票、平均、加权平均等方法对多个模型的输出进行融合。在融合过程中,NNUNet还会对不同尺度的分割结果进行加权,以使模型能够更好地适应不同尺度的图像。 4. 预测过程:在测试时,NNUNet会对输入图像进行多尺度预测,以生成不同尺度的分割结果。然后,NNUNet会将不同尺度的分割结果进行融合,以生成最终的分割结果。 多尺度训练可以带来以下几个优点: 1. 提高模型的泛化性能和鲁棒性:多尺度训练可以使模型能够适应不同尺度和大小的图像,从而提高模型的泛化性能和鲁棒性。 2. 提高模型的精度:使用多个尺度的图像进行训练可以增加训练数据的多样性,从而提高模型的精度。 3. 减少过拟合:多尺度训练可以减少模型的过拟合,从而提高模型的泛化能力和鲁棒性。 总之,NNUNet的多尺度训练过程可以使模型更好地适应不同的输入图像,从而提高模型的性能和鲁棒性。同时,NNUNet还使用了其他技术来进一步提高模型的性能和鲁棒性,如跳跃连接、数据增强和集成学习等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值