一、项目概述
nnUNet(neural network Universal Network)是一款基于深度学习的医学图像分割开源框架,核心定位是为医学影像分割任务提供通用化、自动化、高性能的解决方案。该项目由医学影像与深度学习领域研究者开发,初衷是解决不同医学影像分割任务中 “模型适配性差、参数调优复杂、工程落地成本高” 的痛点,无需用户具备深厚的深度学习工程经验,即可快速适配不同模态、不同器官的分割需求。
项目开源后迅速成为医学影像分割领域的 “标杆工具”,被广泛应用于学术研究与临床前研究场景,其核心设计理念 “数据驱动的自适应配置” 已成为医学图像分割工具的重要设计范式。
二、项目取得的成绩
- 学术竞赛表现:在多个国际顶级医学影像分割竞赛(如 BraTS、MSD Challenge)中持续取得 Top 排名,成为竞赛中最常用的基准框架之一。
- 行业认可度:被超过 1000 篇 SCI 论文引用,涵盖肿瘤分割、器官分割、病灶检测等多个医学影像方向,成为医学深度学习领域的 “标准工具库”。
- 落地适配能力:已成功适配 CT、MRI、PET 等多种医学影像模态,支持脑、肺、肝、肾等 20 + 器官 / 病灶的分割任务,无需大量定制化开发。
- 性能标杆:在公开医学影像数据集(如 LiTS、Pancreas-CT)上,分割准确率(Dice 系数)普遍达到 0.85 以上,部分任务突破 0.95,远超传统分割方法。
三、技术栈详解
核心技术栈
| 技术类别 | 具体技术 / 工具 | 核心作用 |
|---|---|---|
| 编程语言 | Python 3.7+ | 核心开发语言,兼顾开发效率与生态完整性 |
| 深度学习框架 | PyTorch 1.6+ | 模型构建、训练与推理的核心框架,支持动态图与分布式训练 |
| 数据处理 | NumPy、SciPy、SimpleITK | 医学影像读取(DICOM/NIfTI 格式)、数据预处理(重采样、归一化) |
| 工程化工具 | Numba | 加速数值计算(如指标计算、数据增强) |
| 可视化工具 | Matplotlib、Seaborn | 分割结果可视化、训练曲线监控 |
| 分布式训练 | PyTorch DistributedDataParallel | 多 GPU 并行训练,提升大规模数据训练效率 |
核心算法技术
- 网络架构:基于 U-Net 及其变体(U-Net++、3D U-Net),采用编码器 - 解码器结构,引入残差连接与密集连接提升特征传播能力。
- 自适应配置策略:自动根据数据集特性(图像尺寸、模态数、类别数)调整网络参数(卷积核大小、网络深度、 batch size)。
- 数据预处理 pipeline:包括强度归一化(z-score/percentile)、重采样(基于体素间距统一尺度)、标签处理(类别平衡)等自动化流程。
- 训练策略:混合精度训练、学习率余弦退火、早停机制、交叉验证(k-fold),提升模型泛化能力。
- 后处理技术:连通区域分析、孔洞填充,解决分割结果中的 “孤立点” 与 “空洞” 问题。
四、项目优势与劣势
核心优势
- 通用性极强:无需修改模型结构,仅通过数据格式适配,即可支持不同医学影像模态、不同分割任务,降低使用门槛。
- 自动化程度高:数据预处理、网络配置、训练参数调优均实现自动化,非专业开发者也能快速上手。
- 性能表现优异:基于数据驱动的配置策略,模型能自适应数据集特性,分割准确率与鲁棒性远超同类工具。
- 工程化成熟:代码结构清晰、文档完善,支持多 GPU 训练、断点续训、结果自动评估,具备工业级落地潜力。
- 生态完善:兼容主流医学影像格式(DICOM、NIfTI),支持与医学影像处理软件(如 3D Slicer)联动。
主要劣势
- 灵活性不足:自适应配置策略限制了用户对模型结构的深度定制,难以满足特殊场景(如小样本、极端不平衡数据)的个性化需求。
- 计算资源依赖:3D U-Net 架构对硬件要求较高,训练大规模 3D 影像(如全脑 MRI)需多 GPU 支持,单机单卡训练速度较慢。
- 非医学场景适配差:设计初衷聚焦医学影像,对自然图像分割等非医学场景的支持不足,数据预处理 pipeline 难以直接复用。
- 实时性欠缺:推理阶段对大尺寸影像需分块处理,实时性表现一般,难以满足临床实时分割的需求。
- 依赖专业数据格式:对医学影像格式(如 DICOM)的依赖较强,普通用户需额外学习数据格式转换,增加使用成本。
五、典型使用场景
- 学术研究:医学影像分割相关的论文实验、竞赛参与,快速构建基准模型并与新方法对比。
- 临床前研究:医院 / 科研机构的临床前数据分析,如肿瘤体积测量、器官形态分析等辅助研究。
- 多模态影像分割:需要处理 CT、MRI 等多种模态数据的场景,如脑肿瘤(BraTS 数据集)、肝脏肿瘤(LiTS 数据集)分割。
- 小样本医学影像分割:利用 nnUNet 的自适应数据增强与正则化策略,在样本量有限的场景(如罕见病影像分割)中快速构建有效模型。
- 医学影像分割工具开发:作为核心分割模块,集成到医疗 AI 产品中,加速产品落地(如辅助诊断系统、影像分析平台)。
- 教学场景:医学深度学习、医学影像处理课程的实践教学,帮助学生快速理解分割模型的工程实现逻辑。
六、代码结构与核心执行步骤
1. 代码结构(核心目录)
plaintext
nnUNet/
├── nnunet/
│ ├── configuration/ # 配置模块:自适应配置生成、参数管理
│ ├── data_loading/ # 数据加载:影像读取、数据增强、batch生成
│ ├── evaluation/ # 评估模块:Dice系数、Hausdorff距离等指标计算
│ ├── inference/ # 推理模块:模型预测、后处理
│ ├── networks/ # 网络模块:U-Net变体、损失函数定义
│ ├── training/ # 训练模块:训练循环、优化器配置
│ └── utilities/ # 工具函数:影像处理、文件操作、日志管理
├── examples/ # 示例代码:快速上手教程
├── tests/ # 单元测试:模块功能验证
└── setup.py # 安装配置
2. 核心执行步骤
(1)数据准备阶段
- 数据格式转换:将原始医学影像(DICOM)转换为 NIfTI 格式,按 “图像 - 标签” 成对组织。
- 数据目录结构化:遵循 nnUNet 标准目录结构(raw_data、processed_data、results),便于框架自动识别。
(2)数据预处理阶段
- 数据探索:自动分析数据集的图像尺寸、体素间距、强度分布、类别分布等特性。
- 自适应预处理:根据数据特性自动执行重采样(统一体素间距)、强度归一化、标签编码。
- 数据增强:生成训练集的增强样本(随机翻转、旋转、缩放、噪声添加),提升模型泛化能力。
(3)模型配置阶段
- 网络配置生成:根据数据维度(2D/3D)、模态数、类别数,自动选择最优网络架构(2D U-Net/3D U-Net)。
- 训练参数配置:自动设置 batch size、学习率、训练轮数、优化器(AdamW)等参数。
(4)模型训练阶段
- 交叉验证划分:将数据集按 k-fold(默认 5 折)划分,避免过拟合。
- 训练循环执行:执行前向传播(图像输入→特征提取→分割预测)、损失计算(Dice 损失 + 交叉熵损失)、反向传播(参数更新)。
- 模型保存:保存每折训练的最优模型(基于验证集 Dice 系数)。
(5)推理与后处理阶段
- 模型加载:加载训练好的最优模型权重。
- 批量预测:对测试集图像进行分割预测,支持分块推理(处理大尺寸影像)。
- 后处理:通过连通区域分析去除孤立小病灶,填充分割结果中的空洞。
- 结果输出:将分割结果保存为 NIfTI 格式,支持可视化与指标评估。
3. 核心执行时序图
plaintext
┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐
│ 数据准备 │────▶│ 数据预处理 │────▶│ 模型配置 │────▶│ 模型训练 │────▶│ 推理后处理 │
└───────────┘ └───────────┘ └───────────┘ └───────────┘ └───────────┘
│ │ │ │ │
▼ ▼ ▼ ▼ ▼
┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐
│格式转换/ │ ┌───────────┐ │自动选择 │ │k-fold交叉 │ │分块预测/ │
│目录结构化 │ │重采样/归一化/│ │网络/参数 │ │验证/模型保存│ │后处理/结果输出│
│ │ │数据增强 │ │ │ │ │ │ │
└───────────┘ └───────────┘ └───────────┘ └───────────┘ └───────────┘
七、开发示例代码
以下示例代码实现 nnUNet 的核心流程(简化版),涵盖数据准备、模型定义、训练与推理的基础功能。
1. 环境准备
bash
运行
# 安装依赖
pip install torch numpy scipy simpleitk numba matplotlib
2. 数据准备(简化版)
python
运行
import os
import SimpleITK as sitk
import numpy as np
def prepare_nnunet_data(raw_data_dir, output_dir):
"""
简化版数据准备:将DICOM格式转换为nnUNet标准NIfTI格式
"""
# 创建nnUNet标准目录结构
os.makedirs(os.path.join(output_dir, "imagesTr"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "labelsTr"), exist_ok=True)
# 遍历原始DICOM数据
for patient_id in os.listdir(raw_data_dir):
patient_dir = os.path.join(raw_data_dir, patient_id)
if not os.path.isdir(patient_dir):
continue
# 读取DICOM图像
img_reader = sitk.ImageSeriesReader()
img_filenames = img_reader.GetGDCMSeriesFileNames(patient_dir)
img_reader.SetFileNames(img_filenames)
img = img_reader.Execute()
# 读取标签(假设标签为单独的DICOM序列)
label_dir = os.path.join(patient_dir, "label")
label_filenames = img_reader.GetGDCMSeriesFileNames(label_dir)
img_reader.SetFileNames(label_filenames)
label = img_reader.Execute()
# 保存为NIfTI格式(nnUNet标准命名:patient_id_0000.nii.gz,0000表示模态)
sitk.WriteImage(img, os.path.join(output_dir, "imagesTr", f"{patient_id}_0000.nii.gz"))
sitk.WriteImage(label, os.path.join(output_dir, "labelsTr", f"{patient_id}.nii.gz"))
print("数据准备完成,输出目录:", output_dir)
# 调用示例
prepare_nnunet_data(raw_data_dir="./raw_dicom", output_dir="./nnunet_data")
3. 简化版 U-Net 模型定义(核心网络模块)
python
运行
import torch
import torch.nn as nn
class SimpleUNet(nn.Module):
"""简化版3D U-Net,模拟nnUNet核心网络结构"""
def __init__(self, in_channels=1, num_classes=2):
super(SimpleUNet, self).__init__()
# 编码器(下采样)
self.enc1 = self.conv_block(in_channels, 64)
self.enc2 = self.conv_block(64, 128)
self.enc3 = self.conv_block(128, 256)
# 解码器(上采样)
self.dec1 = self.conv_block(256, 128)
self.dec2 = self.conv_block(128, 64)
self.dec3 = self.conv_block(64, num_classes)
# 池化与上采样
self.pool = nn.MaxPool3d(2, 2)
self.upconv = nn.ConvTranspose3d(256, 128, 2, stride=2)
self.final_conv = nn.Conv3d(num_classes, num_classes, 1)
def conv_block(self, in_channels, out_channels):
"""卷积块:Conv3d + BatchNorm + ReLU"""
return nn.Sequential(
nn.Conv3d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm3d(out_channels),
nn.ReLU(inplace=True),
nn.Conv3d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm3d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
# 编码器
x1 = self.enc1(x)
x2 = self.pool(x1)
x2 = self.enc2(x2)
x3 = self.pool(x2)
x3 = self.enc3(x3)
# 解码器
x = self.upconv(x3)
x = torch.cat([x, x2], dim=1) # 跳跃连接
x = self.dec1(x)
x = self.upconv(x)
x = torch.cat([x, x1], dim=1) # 跳跃连接
x = self.dec2(x)
x = self.dec3(x)
out = self.final_conv(x)
return out
# 模型实例化
model = SimpleUNet(in_channels=1, num_classes=2)
print("模型结构:", model)
4. 简化版训练流程
python
运行
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# 自定义数据集(简化版)
class MedicalDataset(Dataset):
def __init__(self, data_dir):
self.image_dir = os.path.join(data_dir, "imagesTr")
self.label_dir = os.path.join(data_dir, "labelsTr")
self.patients = [f.split("_0000.nii.gz")[0] for f in os.listdir(self.image_dir)]
def __len__(self):
return len(self.patients)
def __getitem__(self, idx):
patient_id = self.patients[idx]
# 读取NIfTI图像
img = sitk.ReadImage(os.path.join(self.image_dir, f"{patient_id}_0000.nii.gz"))
img = sitk.GetArrayFromImage(img).astype(np.float32)[None, ...] # (C, D, H, W)
# 读取标签
label = sitk.ReadImage(os.path.join(self.label_dir, f"{patient_id}.nii.gz"))
label = sitk.GetArrayFromImage(label).astype(np.longlong)
return torch.from_numpy(img), torch.from_numpy(label)
# 训练配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_dir = "./nnunet_data"
dataset = MedicalDataset(data_dir)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
# 模型、损失函数、优化器
model = SimpleUNet().to(device)
criterion = nn.CrossEntropyLoss() # 结合Dice损失效果更优,此处简化
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# 训练循环
def train_epoch(model, dataloader, criterion, optimizer, device):
model.train()
total_loss = 0.0
for imgs, labels in dataloader:
imgs, labels = imgs.to(device), labels.to(device)
# 前向传播
outputs = model(imgs)
loss = criterion(outputs, labels)
# 反向传播与优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
# 执行训练
num_epochs = 10
for epoch in range(num_epochs):
train_loss = train_epoch(model, dataloader, criterion, optimizer, device)
print(f"Epoch {epoch+1}/{num_epochs},
2万+

被折叠的 条评论
为什么被折叠?



