目录
-
为什么选择 UNet + CT 分割
-
数据来源(推荐数据集与许可)
-
环境与依赖(conda / pip)
-
数据预处理(核心步骤 + 代码)
-
模型:3D U-Net(MONAI 实现) + 损失与度量
-
训练脚本(含 patch 采样、混合精度)
-
推理与后处理(例:连通域、最小体积过滤)
-
评估(Dice / IoU / 体积误差)
-
部署(Streamlit Demo 与 ONNX 导出示例)
-
常见问题与调参建议
-
完整项目结构与附件(复制即用)
-
结语与扩展方向
1. 为什么选择 UNet + CT 分割
U-Net 架构自 2015 年推出以来成为医学图像分割的基石:编码器-解码器结构精于提取语义并恢复细节,易于扩展到 3D,是 CT/ MRI 等体积分割的首选之一。arXiv
2. 数据来源
-
Medical Segmentation Decathlon (MSD):包含多种器官/病灶的 3D 分割任务,适合初学者做泛化实验;官方站点与论文说明了数据组织与评估方式。medicaldecathlon.com+1
-
LIDC / LUNA16(肺结节):用于肺结节检测/分割研究(如需结节级别评估可用)。(可根据需要并遵守数据许可下载)
-
本文示例用 MSD 的 Spleen / Lung 或公开镜像小样本做演示(若使用医院 DICOM,注意机构审批与脱敏)。
数据使用须遵守原数据许可与医院伦理(IRB)要求。
3. 环境搭建
conda create -n medseg python=3.10 -y
conda activate medseg
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install monai[all] torchio nibabel pydicom SimpleITK scikit-image opencv-python matplotlib streamlit onnx onnxruntime
-
推荐使用 NVIDIA GPU(至少 12GB,训练 patch-size 较大时建议 24GB)。
-
推荐使用 MONAI(专为医学影像设计,内置常用 transforms、网络与训练工具)。GitHub
4. 数据预处理(CT 专用核心步骤 + 代码)
CT 体素的物理间距(spacing)与 HU 值对训练影响非常大,必须统一处理。
关键步骤
-
读取 DICOM / NIfTI(SimpleITK 或 nibabel)
-
HU 校正与截断(常用窗:
[-1000, 400]或[-1350, 800],视任务) -
重采样(resample)到统一 voxel spacing(例如
1.0 × 1.0 × 1.0 mm) -
肺野提取(阈值 + 连通域,或使用快速网络)
-
存成 NifTI 或 .npz 供训练读取(减少 I/O 时间)
-
生成训练用 patch(patch-based training)
下面给出关键函数(SimpleITK 实现):
# utils/preprocess.py
import SimpleITK as sitk
import numpy as np
def read_dicom_series(folder):
reader = sitk.ImageSeriesReader()
series_IDs = reader.GetGDCMSeriesFileNames(folder)
reader.SetFileNames(series_IDs)
image = reader.Execute()
arr = sitk.GetArrayFromImage(image) # z,y,x
spacing = image.GetSpacing()[::-1] # sitk: x,y,z -> convert to z,y,x
origin = image.GetOrigin()
return arr, spacing, origin
def resample_image(arr, spacing, new_spacing=(1.0,1.0,1.0), is_label=False):
img = sitk.GetImageFromArray(arr)
img.SetSpacing((spacing[2], spacing[1], spacing[0])) # sitk expects x,y,z
orig_size = img.GetSize()
new_size = [
int(np.round(orig_size[i] * (img.GetSpacing()[i] / new_spacing[i])))
for i in range(3)
]
resampler = sitk.ResampleImageFilter()
resampler.SetOutputSpacing(new_spacing)
resampler.SetSize(new_size)
resampler.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear)
resampled = resampler.Execute(img)
return sitk.GetArrayFromImage(resampled)
def hu_clip_normalize(arr, hu_min=-1000, hu_max=400):
arr = np.clip(arr, hu_min, hu_max)
arr = (arr - hu_min) / (hu_max - hu_min) # 0-1
return arr.astype(np.float32)
Tip:预处理最好一次做完并存成 NIfTI(或 torch.save 的 tensor),训练读取速度更快。
5. 模型:3D U-Net(MONAI 实现)与损失函数
使用 MONAI 可以非常简洁地构建 3D U-Net:
# model.py
from monai.networks.nets import UNet
def get_unet(in_channels=1, out_channels=1, channels=(16,32,64,128,256)):
model = UNet(
dimensions=3,
in_channels=in_channels,
out_channels=out_channels,
channels=channels,
strides=(2,2,2,2),
num_res_units=2,
norm='batch'
)
return model
损失函数与度量
常用组合:DiceLoss + BCEWithLogitsLoss,指标使用 DiceCoefficient、IoU。
import monai
loss = monai.losses.DiceLoss(sigmoid=True)
bce = torch.nn.BCEWithLogitsLoss()
def mixed_loss(pred, target, alpha=0.5):
return alpha * loss(pred, target) + (1-alpha) * bce(pred, target)
6. 训练脚本(Patch-based + 混合精度 + DataLoader)
思路:对 3D 体积使用 patch 采样(例如 128×128×128 或 96×96×96),用 TorchIO 或 MONAI 的 RandSpatialCrop 进行正/负样本采样,减小显存消耗并提高数据多样性。下面给出示例训练主循环精简版(可扩展):
# train.py (核心片段)
import torch
from torch.utils.data import DataLoader
from monai.transforms import Compose, LoadImage, AddChannel, RandSpatialCrop, RandFlip, ToTensor
from monai.data import CacheDataset, decollate_batch
from monai.metrics import DiceMetric
from model import get_unet
# transforms
train_trans = Compose([
LoadImage(image_only=True),
AddChannel(),
RandSpatialCrop((96,96,96), random_size=False),
RandFlip(prob=0.5, spatial_axis=0),
ToTensor()
])
# dataset (假设有 list_of_dicts 每项包含 image, label)
train_ds = CacheDataset(data=train_files, transform=train_trans)
train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, pin_memory=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_unet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")
scaler = torch.cuda.amp.GradScaler()
for epoch in range(1, epochs+1):
model.train()
epoch_loss = 0
for batch in train_loader:
imgs = batch['image'].to(device)
labs = batch['label'].to(device)
optimizer.zero_grad()
with torch.cuda.amp.autocast():
outputs = model(imgs)
loss = mixed_loss(outputs, labs)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
epoch_loss += loss.item()
# validation & metrics...
Tip:使用
CacheDataset能大幅减少 I/O,但需要更多内存用于缓存。
7. 推理与后处理(连通域与体积过滤)
推理时对整个体积做滑动窗口预测(sliding window inference),MONAI 提供便捷接口 sliding_window_inference:
后处理建议:
-
最小体素数过滤(移除噪声)
-
与肺野掩模相乘以保证预测在肺内
-
计算每个连通域体积(体素数 × voxel_volume → cm³)用于临床参考
8. 评估(Dice / IoU / 体积差异)
常用评估指标:
-
Dice Coefficient(越高越好,1 为完全重合)
-
IoU (Jaccard)
-
体积误差(预测体积 vs 真实体积的相对/绝对误差)
用 MONAI 的评估类可直接计算(见上面 DiceMetric)。
9. 部署(两条易实现路线)
A. 轻量在线 Demo:Streamlit(快速可视化)
把模型导出为 torchscript 或直接在服务器上加载 PyTorch 模型,用 Streamlit 做前端,上载 NIfTI 文件 → 显示切片与预测叠加。
示例 app.py:
import streamlit as st
import nibabel as nib
import numpy as np
import torch
from model import get_unet
@st.cache_resource
def load_model(checkpoint):
model = get_unet()
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model.eval()
return model
st.title("CT Segmentation Demo")
uploaded = st.file_uploader("上传 NIfTI (.nii/.nii.gz)", type=['nii','gz'])
if uploaded:
img = nib.load(uploaded)
arr = img.get_fdata()
model = load_model('checkpt.pth')
# 预处理、推理、展示若干切片
st.image(...) # 可将切片绘制为 PNG 后展示
运行:
streamlit run app.py --server.port 8501
B. 高性能推理:ONNX + ONNXRuntime 或 TorchScript
导出 ONNX(或 TorchScript)并用 ONNXRuntime 做推理以提升吞吐量:
# export_onnx.py
import torch
model = get_unet()
model.load_state_dict(torch.load('checkpt.pth'))
model.eval()
dummy = torch.randn(1,1,96,96,96)
torch.onnx.export(model, dummy, "unet.onnx", opset_version=11)
用 onnxruntime 加速推理,或把模型包装为 REST API(FastAPI + ONNXRuntime)部署到云服务器。
10. 常见问题(FAQ)与调参建议
-
显存 OOM:减小 patch 大小或 batch_size,使用混合精度,或梯度累积。
-
训练不收敛:检查 HU 窗口/重采样是否一致,确保标签对齐。
-
Dice 很高但视觉差:可能是 class imbalance;尝试 focal loss 或对稀少样本做 oversample。
-
泛化差:做更多数据增强或在不同源医院数据上做微调(domain shift)。
11. 完整项目结构
med_unet_project/ ├─ data_raw/ # 原始 DICOM / NIfTI ├─ data_preprocessed/ # resampled & normalized .nii/.npz ├─ src/ │ ├─ utils/ │ │ ├─ preprocess.py │ │ └─ postprocess.py │ ├─ datasets.py │ ├─ model.py │ ├─ train.py │ ├─ infer.py │ └─ export_onnx.py ├─ notebooks/ ├─ requirements.txt └─ README.md
12. 示例:完整 train.py
下面的脚本为精简版本,真实工程请添加日志、断点保存、早停、学习率调度、混合精度更完整处理。
# train.py (精简)
import os, glob
import torch
from monai.transforms import Compose, LoadImage, AddChannel, RandSpatialCrop, RandFlip, ToTensor
from monai.data import CacheDataset, DataLoader
from model import get_unet
from utils.preprocess import hu_clip_normalize
def make_dataset(data_dir):
# 假设 data_dir 下为 pairs of image,label .nii
files = []
for img_p in glob.glob(os.path.join(data_dir, 'images','*.nii*')):
lbl_p = img_p.replace('images','labels')
files.append({'image': img_p, 'label': lbl_p})
return files
if __name__ == '__main__':
train_files = make_dataset('data_preprocessed/train')
train_trans = Compose([LoadImage(image_only=True), AddChannel(), RandSpatialCrop((96,96,96), random_size=False),
RandFlip(prob=0.5), ToTensor()])
ds = CacheDataset(data=train_files, transform=train_trans)
loader = DataLoader(ds, batch_size=2, num_workers=4, pin_memory=True)
device = torch.device('cuda')
model = get_unet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(1,101):
model.train()
for batch in loader:
img = batch['image'].to(device)
lab = batch['label'].to(device)
out = model(img)
loss = torch.nn.functional.binary_cross_entropy_with_logits(out, lab)
opt.zero_grad(); loss.backward(); opt.step()
print(f"Epoch {epoch} loss {loss.item():.4f}")
if epoch % 10 == 0:
torch.save(model.state_dict(), f'checkpt_epoch{epoch}.pth')
13. 参考与资源
-
U-Net 原始论文(Ronneberger et al., 2015)
-
Medical Segmentation Decathlon(MSD)官方页面与说明
-
MONAI:医疗影像的 PyTorch 框架(GitHub & 文档)
-
TorchIO:医学体积预处理与 patch 采样工具

3万+

被折叠的 条评论
为什么被折叠?



