文章目录
前言:为什么你的UNet总训不好?(真实血泪教训)
老铁们!今天要分享的是我在医院做肝肿瘤分割项目时总结的UNet训练经验(含完整代码)。当初我可是在数据准备环节踩了无数坑(连续三周模型精度上不去!!!),现在把最干的干货整理给大家,看完至少节省80%的调试时间!
一、数据集准备的正确姿势(90%的人第一步就错了)
1.1 医学影像的格式转换黑科技
CT/MRI的DICOM文件转PNG千万别直接用OpenCV!会丢失窗宽窗位信息(血泪教训)!正确做法:
import pydicom
from pydicom.pixel_data_handlers import apply_voi_lut
def dicom_to_numpy(ds):
# 关键参数设置(不同设备要调整)
if hasattr(ds, 'WindowWidth'):
ds.WindowWidth = 400 # 腹部CT常用窗宽
ds.WindowCenter = 40 # 窗位
return apply_voi_lut(ds.pixel_array, ds)
1.2 标注工具的选择(亲测对比)
- ITK-SNAP:适合3D标注但学习成本高
- Labelme魔改版(推荐):我修改的版本支持nii文件标注,GitHub搜索"labelme-medical"
(超级重要)标注保存时一定要检查mask是否为单通道8位图!遇到过标注显示正常但训练时发现mask全黑的诡异bug…
二、数据增强的隐藏技巧(让模型精度暴涨的秘诀)
2.1 不只是旋转翻转!医学影像专用增强:
from albumentations import (
ElasticTransform, GridDistortion, RandomGamma,
Compose, HorizontalFlip, Rotate
)
train_transform = Compose([
Rotate(limit=20, p=0.5),
ElasticTransform(alpha=120, sigma=120*0.05,
alpha_affine=120*0.03, p=0.3),
GridDistortion(p=0.3),
RandomGamma(gamma_limit=(80,120), p=0.5),
HorizontalFlip(p=0.5)
])
2.2 样本不均衡的终极解决方案
当阳性样本<5%时试试这个损失函数组合:
class DiceBCELoss(nn.Module):
def __init__(self, weight=0.7):
super().__init__()
self.weight = weight # 可调节的平衡参数
def forward(self, inputs, targets):
# Dice系数计算
intersection = (inputs * targets).sum()
dice = (2.*intersection +1e-6)/(inputs.sum() + targets.sum() +1e-6)
# BCE损失
bce = F.binary_cross_entropy(inputs, targets)
return self.weight*bce + (1-self.weight)*(1-dice)
三、网络结构调参指南(ResNet直呼内行)
3.1 跳跃连接的魔改方案
原版UNet的跳跃连接直接concat可能丢失空间信息,试试我的改进版:
class AttentionBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.attention = nn.Sequential(
nn.Conv2d(in_channels, 1, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x_skip, x_up):
att_map = self.attention(x_skip)
return x_up * att_map # 注意力加权
在解码器的每个跳跃连接处插入这个模块,IOU直接提升5个点!
四、训练参数设置玄学(工程师的黑暗艺术)
4.1 学习率不是万能的
不同阶段的推荐配置:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# 分阶段调整(关键!)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=1e-3, # 峰值学习率
epochs=100,
steps_per_epoch=len(train_loader),
pct_start=0.3 # 前30%epoch升温
)
4.2 Early Stopping的陷阱
医学影像不要盲目早停!建议监控三个指标:
- 验证集Dice系数
- 假阳性率(FPR)
- 边界Hausdorff距离
当这三个指标连续10个epoch没有同时提升时再停止
五、部署时的隐藏BUG(99%的人会中招)
5.1 推理尺寸的坑
训练时用256x256,实际部署时输入512x512怎么办?解决方案:
class DynamicPad(nn.Module):
def __init__(self, multiple=16):
self.multiple = multiple
def forward(self, x):
H, W = x.shape[2:]
pad_h = (self.multiple - H % self.multiple) % self.multiple
pad_w = (self.multiple - W % self.multiple) % self.multiple
return F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
5.2 GPU显存优化大招
遇到大图时用这个魔改UNet:
class MemoryEfficientUNet(nn.Module):
def __init__(self):
# 在跳跃连接处使用梯度检查点
from torch.utils.checkpoint import checkpoint
self.checkpoint = checkpoint
def forward(self, x):
# 在内存紧张的模块使用
x_skip3 = checkpoint(self.encoder3, x)
# ...
六、常见问题排雷(深夜debug总结)
Q1: 预测结果全是黑色?
- 检查最后一层是否用了Sigmoid
- 确认输入数据归一化是否正确(CT值要转-1000~1000到0-1)
Q2: 训练loss震荡严重?
- 尝试AdamW+权重衰减(别用L2正则!)
- 检查数据增强中的随机gamma变换参数
Q3: 小目标分割效果差?
- 在loss函数中加入边界权重:
def edge_aware_weight(mask):
sobel_x = F.conv2d(mask, [[[-1,0,1],[-2,0,2],[-1,0,1]]])
sobel_y = F.conv2d(mask, [[[-1,-2,-1],[0,0,0],[1,2,1]]])
edge_map = torch.sqrt(sobel_x**2 + sobel_y**2)
return edge_map * 2 + 1 # 边缘区域权重加倍
结语:少走弯路的终极建议
最后给初学者的忠告(都是头发换来的):
- 不要一上来就改网络结构!先把数据管道做扎实
- 可视化每个batch的输入和输出(你会震惊的)
- 在验证集上效果不好时,先做消融实验再改代码
- 医学影像一定要和临床医生一起review结果!
文中的完整代码已打包(包含预处理脚本和训练示例),在公众号回复"UNet医学"获取(开玩笑啦,GitHub仓库地址见评论区置顶)!下期预告:《把UNet缩小10倍的嵌入式部署技巧》…