【nnUNet v2版本 如何训练自己设计的网络】

任务驱动:由于参加挑战赛,官方对提交代码没有限制,仅希望对已有方法进行改进即可。因此选用了nnUNetv2作为提交,同时希望训练自己设计的网络。但是网上提供的点比较散,这里以官方文档为主线,同时吸纳了其他作者的思路,最终完成了修改。

nnUNetv2 修改训练网络通关实录

by Gunhild

0.nnUNetv2 安装

不过多赘述,贴一个csdn教程:链接: nnU-Net v2的环境配置到训练自己的数据集(详细步骤)

1.打开nnUNetv2 的文档,看看作者推荐的修改方式:`

extending_nnunet

点开extending_nnunet.md 文件,我们可以看到作者列出了以下几行:

作者给出的两种网络修改方法
这里我们用第一种,这就要求我们找到:函数build_network_architecture和nnUNetTrainerNoDeepSupervision.py两个文件,那么他们在哪个位置呢?

2.寻找build_network_architecturennUNetTrainerNoDeepSupervision.py

首先来看看build_network_architecture函数在哪,它位于图中所示路径里:

build_network_architecture
这是build_network_architecture函数的基本内容,但是我们发现,它的函数功能已经被写死在运行计划里了,nnunet的运行计划在数据预处理时便被写好。

暂且搁置这个函数,我们来看看nnUNetTrainerNoDeepSupervision.py,它给出了非深监督网络的基本结构写法。
需要说明的是,这个教程只会给出最简单的非深监督实现方法,因此对于深监督网络结构的写法,我后面如果有空就单独出一期
所在路径

东西很长,但是我们只需要关注 initialize 这个函数,它负责初始化nnunet的网络模型架构

在这里插入图片描述

问题来了:怎么修改成自己的网络结构呢?
这里我另外写了一个生成网络结构的函数,称为create_model.py
如图中红框所示:红框的下方有行被注释掉的网络模型生成函数,可以借此导入自己的网络模型。

在这里插入图片描述

如此一来,我们来看看create_model.py文件需要怎么写:

import torch
from nnunetv2.training.nnUNetTrainer.KG_VNet import KG_VNet
from nnunetv2.training.nnUNetTrainer.KG_HNet import KG_HNet

def create_model(in_channel=3, num_class=2, name='KG_VNet'):
    if name == 'KG_VNet':
        print("="*20)
        print("now use our kgvnet")
        print("=" * 20)
        return KG_VNet(in_channel=in_channel, out_channel=num_class, pretrained=True, backbone_num='34', att_name='ca')
    elif name == 'KG_HNet':
        print("="*20)
        print("now use our kghnet")
        print("=" * 20)
        return KG_HNet(in_channel=in_channel, out_channel=num_class, pretrained=False, backbone_num='34', att_name='ca')


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_model(in_channel=4, num_class=2, name='KG_HNet').to(device)
    input = torch.randn([1, 4, 128, 128, 128]).to(device)
    print(model(input).shape)

如上是我的一种书写模式,出于方便我没有写得很规范(if没有else,容易报错
为了书写简洁,我将网络结构单独剥离出,由import的方式写进去,这里大家import自己的网络模型函数就可以了。

除了修改以上文件,我们也要把模型推理的文件处理好,如下:
在这里插入图片描述

将里面的network替换成create_model

在这里插入图片描述

3.修改完,大功告成,运行我们自己的nnUNet:

现在,我们的网络结构修改好了,该运行了,运行的命令如下:

nnUNetv2_train $dataset_num 3d_fullres $fold -tr nnUNetTrainerNoDeepSupervision

dataset_num就是第几个数据集,fold就是跑第几折的意思, -tr nnUNetTrainerNoDeepSupervision 表示使用非深监督方法运行nnunet,这样我们的网络就跑起来了。

4.问题解决


4.1 nnUNet 中 import 和 from ** import ** 的注意事项

想必整篇文章看下来,大家看到的代码中,nnUNet对于自己的import已经是设定好绝对初始路径了,我们可以看看:
在这里插入图片描述

我们在安装nnUNet的必做事项中,就把nnUNet的绝对路径设为了从nnunetv2开始。所以我们注入自己的网络结构代码时,也要遵循这个路径:再贴一张我的代码,作为一个参照。

在这里插入图片描述

4.2 nnUNet 非深监督py文件仅存在init函数

大家好,果然如我所料,漏掉了一个关键问题,感谢这位网友@chand_ler 指出

如图,原始py文件仅有init的函数,我们需要补足它。
在这里插入图片描述
这里我参照了这个博主的写法, UUNet训练自己写的网络
它的写法很全,大家直接参考即可。


这里预留问题储备框格,若遇到一些常见问题会写在这里。

评论 24
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值