基于 UNet 的医学图像分割实战:从肺部 CT 分割到模型部署

部署运行你感兴趣的模型镜像

目录

  1. 为什么选择 UNet + CT 分割

  2. 数据来源(推荐数据集与许可)

  3. 环境与依赖(conda / pip)

  4. 数据预处理(核心步骤 + 代码)

  5. 模型:3D U-Net(MONAI 实现) + 损失与度量

  6. 训练脚本(含 patch 采样、混合精度)

  7. 推理与后处理(例:连通域、最小体积过滤)

  8. 评估(Dice / IoU / 体积误差)

  9. 部署(Streamlit Demo 与 ONNX 导出示例)

  10. 常见问题与调参建议

  11. 完整项目结构与附件(复制即用)

  12. 结语与扩展方向


1. 为什么选择 UNet + CT 分割

U-Net 架构自 2015 年推出以来成为医学图像分割的基石:编码器-解码器结构精于提取语义并恢复细节,易于扩展到 3D,是 CT/ MRI 等体积分割的首选之一。arXiv


2. 数据来源

  • Medical Segmentation Decathlon (MSD):包含多种器官/病灶的 3D 分割任务,适合初学者做泛化实验;官方站点与论文说明了数据组织与评估方式。medicaldecathlon.com+1

  • LIDC / LUNA16(肺结节):用于肺结节检测/分割研究(如需结节级别评估可用)。(可根据需要并遵守数据许可下载)

  • 本文示例用 MSD 的 Spleen / Lung 或公开镜像小样本做演示(若使用医院 DICOM,注意机构审批与脱敏)。

数据使用须遵守原数据许可与医院伦理(IRB)要求。


3. 环境搭建

conda create -n medseg python=3.10 -y
conda activate medseg
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install monai[all] torchio nibabel pydicom SimpleITK scikit-image opencv-python matplotlib streamlit onnx onnxruntime
  • 推荐使用 NVIDIA GPU(至少 12GB,训练 patch-size 较大时建议 24GB)。

  • 推荐使用 MONAI(专为医学影像设计,内置常用 transforms、网络与训练工具)。GitHub


4. 数据预处理(CT 专用核心步骤 + 代码)

CT 体素的物理间距(spacing)与 HU 值对训练影响非常大,必须统一处理。

关键步骤

  1. 读取 DICOM / NIfTI(SimpleITK 或 nibabel)

  2. HU 校正与截断(常用窗:[-1000, 400][-1350, 800],视任务)

  3. 重采样(resample)到统一 voxel spacing(例如 1.0 × 1.0 × 1.0 mm

  4. 肺野提取(阈值 + 连通域,或使用快速网络)

  5. 存成 NifTI 或 .npz 供训练读取(减少 I/O 时间)

  6. 生成训练用 patch(patch-based training)

下面给出关键函数(SimpleITK 实现):

# utils/preprocess.py
import SimpleITK as sitk
import numpy as np

def read_dicom_series(folder):
    reader = sitk.ImageSeriesReader()
    series_IDs = reader.GetGDCMSeriesFileNames(folder)
    reader.SetFileNames(series_IDs)
    image = reader.Execute()
    arr = sitk.GetArrayFromImage(image)  # z,y,x
    spacing = image.GetSpacing()[::-1]  # sitk: x,y,z -> convert to z,y,x
    origin = image.GetOrigin()
    return arr, spacing, origin

def resample_image(arr, spacing, new_spacing=(1.0,1.0,1.0), is_label=False):
    img = sitk.GetImageFromArray(arr)
    img.SetSpacing((spacing[2], spacing[1], spacing[0]))  # sitk expects x,y,z
    orig_size = img.GetSize()
    new_size = [
        int(np.round(orig_size[i] * (img.GetSpacing()[i] / new_spacing[i])))
        for i in range(3)
    ]
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(new_spacing)
    resampler.SetSize(new_size)
    resampler.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear)
    resampled = resampler.Execute(img)
    return sitk.GetArrayFromImage(resampled)

def hu_clip_normalize(arr, hu_min=-1000, hu_max=400):
    arr = np.clip(arr, hu_min, hu_max)
    arr = (arr - hu_min) / (hu_max - hu_min)  # 0-1
    return arr.astype(np.float32)

Tip:预处理最好一次做完并存成 NIfTI(或 torch.save 的 tensor),训练读取速度更快。


5. 模型:3D U-Net(MONAI 实现)与损失函数

使用 MONAI 可以非常简洁地构建 3D U-Net:

# model.py
from monai.networks.nets import UNet

def get_unet(in_channels=1, out_channels=1, channels=(16,32,64,128,256)):
    model = UNet(
        dimensions=3,
        in_channels=in_channels,
        out_channels=out_channels,
        channels=channels,
        strides=(2,2,2,2),
        num_res_units=2,
        norm='batch'
    )
    return model

损失函数与度量

常用组合:DiceLoss + BCEWithLogitsLoss,指标使用 DiceCoefficientIoU

import monai
loss = monai.losses.DiceLoss(sigmoid=True)
bce = torch.nn.BCEWithLogitsLoss()
def mixed_loss(pred, target, alpha=0.5):
    return alpha * loss(pred, target) + (1-alpha) * bce(pred, target)

6. 训练脚本(Patch-based + 混合精度 + DataLoader)

思路:对 3D 体积使用 patch 采样(例如 128×128×128 或 96×96×96),用 TorchIO 或 MONAI 的 RandSpatialCrop 进行正/负样本采样,减小显存消耗并提高数据多样性。下面给出示例训练主循环精简版(可扩展):

# train.py (核心片段)
import torch
from torch.utils.data import DataLoader
from monai.transforms import Compose, LoadImage, AddChannel, RandSpatialCrop, RandFlip, ToTensor
from monai.data import CacheDataset, decollate_batch
from monai.metrics import DiceMetric
from model import get_unet

# transforms
train_trans = Compose([
    LoadImage(image_only=True),
    AddChannel(),
    RandSpatialCrop((96,96,96), random_size=False),
    RandFlip(prob=0.5, spatial_axis=0),
    ToTensor()
])

# dataset (假设有 list_of_dicts 每项包含 image, label)
train_ds = CacheDataset(data=train_files, transform=train_trans)
train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, pin_memory=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_unet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

scaler = torch.cuda.amp.GradScaler()
for epoch in range(1, epochs+1):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        imgs = batch['image'].to(device)
        labs = batch['label'].to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            loss = mixed_loss(outputs, labs)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    # validation & metrics...

Tip:使用 CacheDataset 能大幅减少 I/O,但需要更多内存用于缓存。


7. 推理与后处理(连通域与体积过滤)

推理时对整个体积做滑动窗口预测(sliding window inference),MONAI 提供便捷接口 sliding_window_inference

后处理建议:

  • 最小体素数过滤(移除噪声)

  • 与肺野掩模相乘以保证预测在肺内

  • 计算每个连通域体积(体素数 × voxel_volume → cm³)用于临床参考


8. 评估(Dice / IoU / 体积差异)

常用评估指标:

  • Dice Coefficient(越高越好,1 为完全重合)

  • IoU (Jaccard)

  • 体积误差(预测体积 vs 真实体积的相对/绝对误差)

用 MONAI 的评估类可直接计算(见上面 DiceMetric)。


9. 部署(两条易实现路线)

A. 轻量在线 Demo:Streamlit(快速可视化)

把模型导出为 torchscript 或直接在服务器上加载 PyTorch 模型,用 Streamlit 做前端,上载 NIfTI 文件 → 显示切片与预测叠加。

示例 app.py

import streamlit as st
import nibabel as nib
import numpy as np
import torch
from model import get_unet

@st.cache_resource
def load_model(checkpoint):
    model = get_unet()
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    model.eval()
    return model

st.title("CT Segmentation Demo")
uploaded = st.file_uploader("上传 NIfTI (.nii/.nii.gz)", type=['nii','gz'])
if uploaded:
    img = nib.load(uploaded)
    arr = img.get_fdata()
    model = load_model('checkpt.pth')
    # 预处理、推理、展示若干切片
    st.image(...)  # 可将切片绘制为 PNG 后展示

运行:

streamlit run app.py --server.port 8501

B. 高性能推理:ONNX + ONNXRuntime 或 TorchScript

导出 ONNX(或 TorchScript)并用 ONNXRuntime 做推理以提升吞吐量:

# export_onnx.py
import torch
model = get_unet()
model.load_state_dict(torch.load('checkpt.pth'))
model.eval()
dummy = torch.randn(1,1,96,96,96)
torch.onnx.export(model, dummy, "unet.onnx", opset_version=11)

onnxruntime 加速推理,或把模型包装为 REST API(FastAPI + ONNXRuntime)部署到云服务器。


10. 常见问题(FAQ)与调参建议

  • 显存 OOM:减小 patch 大小或 batch_size,使用混合精度,或梯度累积。

  • 训练不收敛:检查 HU 窗口/重采样是否一致,确保标签对齐。

  • Dice 很高但视觉差:可能是 class imbalance;尝试 focal loss 或对稀少样本做 oversample。

  • 泛化差:做更多数据增强或在不同源医院数据上做微调(domain shift)。


11. 完整项目结构

med_unet_project/
├─ data_raw/                # 原始 DICOM / NIfTI
├─ data_preprocessed/       # resampled & normalized .nii/.npz
├─ src/
│  ├─ utils/
│  │   ├─ preprocess.py
│  │   └─ postprocess.py
│  ├─ datasets.py
│  ├─ model.py
│  ├─ train.py
│  ├─ infer.py
│  └─ export_onnx.py
├─ notebooks/
├─ requirements.txt
└─ README.md

12. 示例:完整 train.py

下面的脚本为精简版本,真实工程请添加日志、断点保存、早停、学习率调度、混合精度更完整处理。

# train.py (精简)
import os, glob
import torch
from monai.transforms import Compose, LoadImage, AddChannel, RandSpatialCrop, RandFlip, ToTensor
from monai.data import CacheDataset, DataLoader
from model import get_unet
from utils.preprocess import hu_clip_normalize

def make_dataset(data_dir):
    # 假设 data_dir 下为 pairs of image,label .nii
    files = []
    for img_p in glob.glob(os.path.join(data_dir, 'images','*.nii*')):
        lbl_p = img_p.replace('images','labels')
        files.append({'image': img_p, 'label': lbl_p})
    return files

if __name__ == '__main__':
    train_files = make_dataset('data_preprocessed/train')
    train_trans = Compose([LoadImage(image_only=True), AddChannel(), RandSpatialCrop((96,96,96), random_size=False),
                           RandFlip(prob=0.5), ToTensor()])
    ds = CacheDataset(data=train_files, transform=train_trans)
    loader = DataLoader(ds, batch_size=2, num_workers=4, pin_memory=True)

    device = torch.device('cuda')
    model = get_unet().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(1,101):
        model.train()
        for batch in loader:
            img = batch['image'].to(device)
            lab = batch['label'].to(device)
            out = model(img)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(out, lab)
            opt.zero_grad(); loss.backward(); opt.step()
        print(f"Epoch {epoch} loss {loss.item():.4f}")
        if epoch % 10 == 0:
            torch.save(model.state_dict(), f'checkpt_epoch{epoch}.pth')

13. 参考与资源

  • U-Net 原始论文(Ronneberger et al., 2015)

  • Medical Segmentation Decathlon(MSD)官方页面与说明

  • MONAI:医疗影像的 PyTorch 框架(GitHub & 文档)

  • TorchIO:医学体积预处理与 patch 采样工具

您可能感兴趣的与本文相关的镜像

ComfyUI

ComfyUI

AI应用
ComfyUI

ComfyUI是一款易于上手的工作流设计工具,具有以下特点:基于工作流节点设计,可视化工作流搭建,快速切换工作流,对显存占用小,速度快,支持多种插件,如ADetailer、Controlnet和AnimateDIFF等

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

权泽谦

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值