1. 准备工作
硬件:需要配备nvidia独显,RTX3090或RTX2080系列等。较大的显存对分割任务能有一定的帮助。也可以租用谷歌、阿里等服务器,价格一般按提供的使用时间和算力计算。
软件:一般使用ubuntu,借助anaconda配置pytorch环境。IDE个人推荐使用Pycharm。
https://pytorch.org/get-started/locally/
2. 数据预处理
2.1 package
一般的二维图像处理依赖opencv,对于医学数据,如MRI、CT等,可借助SimpleITK,NiBabel等包进行处理。
2.2 基本处理
医学图像文件后缀名一般为nii,nii.gz,dicom等,例举使用simpleitk和opncv的相关基本操作
import SimpleITK as sitk
import cv2
img = sitk.ReadImage(img_filename)
img_array = sitk.GetArrayFromImage(img)
img_array = np.transpose(img_array, (1, 2, 0)) # x,y,z--210 120-yxz
spacing = img.GetSpacing()
origin = img.GetOrigin()
direction = img.GetDirection()
reimg.SetOrigin(origin)
reimg.SetSpacing(tar_spacing)
reimg.SetDirection(direction)
sitk.WriteImage(relab, join(preproc_path, '%s_prelab.nrrd') % filename_now)
# normalization
def norm(img_array):
mri_max = np.amax(img_array)
mri_min = np.amin(img_array)
mri_img = ((img_array - mri_min) / (mri_max - mri_min)) * 255
mri_img = mri_img.astype('uint8')
return mri_img
# contrast limited adaptive histogram equalization
def clahe(mri_img):
h, w, d = mri_img.shape
img_clahe_add = np.zeros_like(mri_img)
for k in range(d):
temp = mri_img[:, :, k]
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
img_clahe = clahe.apply(temp)
# cv2.imshow('mri', np.concatenate([temp,img_clahe], 1))
# cv2.waitKey(1)
img_clahe_add[:, :, k] = img_clahe
return img_clahe_add
3. 训练测试集划分
对于公开数据集,一般已经固定了训练集、验证集和测试集。对于私有数据集,则需要自己进行划分,一般采用交叉验证的方式以说明模型的范化性。
例如k折交叉验证(K-fold cross validation),就是把样本集S分成k份,分别使用其中的(k-1)份作为训练集,剩下的1份作为交叉验证集,最后取最后的平均误差,来评估这个模型。
4. 评价指标
每一词训练得到一个输出label,和医生手工标注的金标准进行比较,计算损失并反向传播,不断重复 。
医学上常用的指标为Dice系数,Dice系数是一种集合相似度度量函数,通常用于计算两个样本的相似度,取值范围在[0,1]:
Dice:
其中 |X∩Y| 是X和Y之间的交集,|X|和|Y|分表表示X和Y的元素的个数,其中,分子的系数为2,是因为分母存在重复计算X和Y之间的共同元素的原因。
Dice Loss:
https://zhuanlan.zhihu.com/p/86704421
5. 代码实现
工程实现引用"Attention U-Net: Learning Where to Look for the Pancreas", MIDL'18, Amsterdam
https://github.com/ozan-oktay/Attention-Gated-Networks
5.1 参数配置
涉及大量参数时,可以使用json、yaml等文件记录需要配置的参数,并在主函数入口进行读取。
如下yaml格式文件,保存需要的变量参数
manual_seed: 0
device:
cuda: 0
data:
data_path: data/resampled0302
aug:
zoom_shape: [160,160,32]
shift_val: [0.1,0.1]
rotate_val: 5.0
scale_val: [1.0,1.0]
random_flip_prob: 0.5
train:
is_train: True
is_test: True
n_epochs: 1200
batch_size: 1
model:
task: segment
model_name: matt
criterion: dice_ce
in_channels: 1
n_classes: 2
optimizer: adam
feature_scale: 8
在初始化时,