任务驱动:由于参加挑战赛,官方对提交代码没有限制,仅希望对已有方法进行改进即可。因此选用了nnUNetv2作为提交,同时希望训练自己设计的网络。但是网上提供的点比较散,这里以官方文档为主线,同时吸纳了其他作者的思路,最终完成了修改。
nnUNetv2 修改训练网络通关实录
by Gunhild
0.nnUNetv2 安装
不过多赘述,贴一个csdn教程:链接: nnU-Net v2的环境配置到训练自己的数据集(详细步骤)
1.打开nnUNetv2 的文档,看看作者推荐的修改方式:`
点开extending_nnunet.md 文件,我们可以看到作者列出了以下几行:
这里我们用第一种,这就要求我们找到:函数build_network_architecture和nnUNetTrainerNoDeepSupervision.py两个文件,那么他们在哪个位置呢?
2.寻找build_network_architecture和nnUNetTrainerNoDeepSupervision.py:
首先来看看build_network_architecture函数在哪,它位于图中所示路径里:
这是build_network_architecture函数的基本内容,但是我们发现,它的函数功能已经被写死在运行计划里了,nnunet的运行计划在数据预处理时便被写好。
暂且搁置这个函数,我们来看看nnUNetTrainerNoDeepSupervision.py,它给出了非深监督网络的基本结构写法。
需要说明的是,这个教程只会给出最简单的非深监督实现方法,因此对于深监督网络结构的写法,我后面如果有空就单独出一期
东西很长,但是我们只需要关注 initialize 这个函数,它负责初始化nnunet的网络模型架构
问题来了:怎么修改成自己的网络结构呢?
这里我另外写了一个生成网络结构的函数,称为create_model.py,
如图中红框所示:红框的下方有行被注释掉的网络模型生成函数,可以借此导入自己的网络模型。
如此一来,我们来看看create_model.py文件需要怎么写:
import torch
from nnunetv2.training.nnUNetTrainer.KG_VNet import KG_VNet
from nnunetv2.training.nnUNetTrainer.KG_HNet import KG_HNet
def create_model(in_channel=3, num_class=2, name='KG_VNet'):
if name == 'KG_VNet':
print("="*20)
print("now use our kgvnet")
print("=" * 20)
return KG_VNet(in_channel=in_channel, out_channel=num_class, pretrained=True, backbone_num='34', att_name='ca')
elif name == 'KG_HNet':
print("="*20)
print("now use our kghnet")
print("=" * 20)
return KG_HNet(in_channel=in_channel, out_channel=num_class, pretrained=False, backbone_num='34', att_name='ca')
if __name__ == '__main__':
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_model(in_channel=4, num_class=2, name='KG_HNet').to(device)
input = torch.randn([1, 4, 128, 128, 128]).to(device)
print(model(input).shape)
如上是我的一种书写模式,出于方便我没有写得很规范(if没有else,容易报错)
为了书写简洁,我将网络结构单独剥离出,由import的方式写进去,这里大家import自己的网络模型函数就可以了。
除了修改以上文件,我们也要把模型推理的文件处理好,如下:
将里面的network替换成create_model
3.修改完,大功告成,运行我们自己的nnUNet:
现在,我们的网络结构修改好了,该运行了,运行的命令如下:
nnUNetv2_train $dataset_num 3d_fullres $fold -tr nnUNetTrainerNoDeepSupervision
dataset_num就是第几个数据集,fold就是跑第几折的意思, -tr nnUNetTrainerNoDeepSupervision 表示使用非深监督方法运行nnunet,这样我们的网络就跑起来了。
4.问题解决
4.1 nnUNet 中 import 和 from ** import ** 的注意事项
想必整篇文章看下来,大家看到的代码中,nnUNet对于自己的import已经是设定好绝对初始路径了,我们可以看看:
我们在安装nnUNet的必做事项中,就把nnUNet的绝对路径设为了从nnunetv2开始。所以我们注入自己的网络结构代码时,也要遵循这个路径:再贴一张我的代码,作为一个参照。
4.2 nnUNet 非深监督py文件仅存在init函数
大家好,果然如我所料,漏掉了一个关键问题,感谢这位网友@chand_ler 指出
如图,原始py文件仅有init的函数,我们需要补足它。
这里我参照了这个博主的写法, UUNet训练自己写的网络
它的写法很全,大家直接参考即可。
这里预留问题储备框格,若遇到一些常见问题会写在这里。