U-Net是医学图像分割中最经典的分割网络,目前(2020.04.03)已经有过万的引用。
过去的这5年里,针对提升U-Net有非常多的工作,但对于U-Net本身,不同的人实现出来性能大概率不一样,Fabian组开发的这个U-Net实现,是目前笔者看到性能最好的,但同时这个代码每一部分都做了精细的优化:预处理,数据增广,后处理,集成等等,不同模块之间耦合程度较高,因此想要把自己的一些想法加入到这个框架,首先要对这个代码有较为深入的了解。接下来会陆陆续续总结下自己肤浅的理解,笔者没经过专业的编程训练,基本都是看网课自学的,水平很菜,理解不到位的地方,恳请读者多多指正。
Automated Design of Deep Learning Methods for Biomedical Image Segmentationarxiv.org https://github.com/MIC-DKFZ/nnUNetgithub.com对nnUNet代码的学习分下面四个层次
- 做了什么?
- 为什么做这个?
- 代码怎么实现的?
- 为什么这么实现?
准备好数据集后,要运行
nnUNet_plan_and_preprocess -t XXX --verify_dataset_integrity
这一步要做的事情是规划实验以及对图像数据进行预处理,设计到的函数见上图。本贴介绍crop。
- 做了什么?对图像背景进行裁剪
- 为什么做这个?以BraTS为代表的数据集,背景有一部分是全黑的(灰度值为0),这部分没有信息含量,裁掉后不影响后续的学习过程,反而能显著减小图像大小,减少计算量。
如何实现crop?
nnUNet_plan_and_preprocess中调用用了下面的crop函数
def crop(task_string, override=False, num_threads=default_num_threads):
cropped_out_dir = join(nnUNet_cropped_data, task_string)
maybe_mkdir_p(cropped_out_dir)
if override and isdir(cropped_out_dir):
shutil.rmtree(cropped_out_dir)
maybe_mkdir_p(cropped_out_dir)
splitted_4d_output_dir_task = join(nnUNet_raw_data, task_string)
lists, _ = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)
imgcrop = ImageCropper(num_threads, cropped_out_dir)
imgcrop.run_cropping(lists, overwrite_existing=override)
shutil.copy(join(nnUNet_raw_data, task_string, "dataset.json"), cropped_out_dir)
核心是下面两行
imgcrop = ImageCropper(num_threads, cropped_out_dir)
imgcrop.run_cropping(lists, overwrite_existing=override)
def run_cropping(self, list_of_files, overwrite_existing=False, output_folder=None):
"""
also copied ground truth nifti segmentation into the preprocessed folder so that we can use them for evaluation
on the cluster
:param list_of_files: list of list of files [[PATIENTID_TIMESTEP_0000.nii.gz], [PATIENTID_TIMESTEP_0000.nii.gz]]
:param overwrite_existing:
:param output_folder:
:return:
"""
if output_folder is not None:
self.output_folder = output_folder
output_folder_gt = os.path.join(self.output_folder, "gt_segmentations")
maybe_mkdir_p(output_folder_gt)
for j, case in enumerate(list_of_files):
if case[-1] is not None:
shutil.copy(case[-1], output_folder_gt)
list_of_args = []
for j, case in enumerate(list_of_files):
case_identifier = get_case_identifier(case)
list_of_args.append((case, case_identifier, overwrite_existing))
p = Pool(self.num_threads)
p.map(self._load_crop_save_star, list_of_args)
p.close()
p.join()
这个函数前面的部分都是在准备参数,然后通过
p.map(self._load_crop_save_star, list_of_args)
把参数传入这个私有方法后,调用函数load_crop_save,corp的过程发生在crop_from_list_of_files,最终的结果会保存成npz文件,boundingbox的信息存储在pkl中。
def _load_crop_save_star(self, args):
return self.load_crop_save(*args)
def load_crop_save(self, case, case_identifier, overwrite_existing=False):
try:
print(case_identifier)
if overwrite_existing
or (not os.path.isfile(os.path.join(self.output_folder, "%s.npz" % case_identifier))
or not os.path.isfile(os.path.join(self.output_folder, "%s.pkl" % case_identifier))):
data, seg, properties = self.crop_from_list_of_files(case[:-1], case[-1])
all_data = np.vstack((data, seg))
np.savez_compressed(os.path.join(self.output_folder, "%s.npz" % case_identifier), data=all_data)
with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f:
pickle.dump(properties, f)
except Exception as e:
print("Exception in", case_identifier, ":")
print(e)
raise e
负责crop的两个主要函数都被定义成了静态方法,
@staticmethod
def crop_from_list_of_files(data_files, seg_file=None):
data, seg, properties = load_case_from_list_of_files(data_files, seg_file)
return ImageCropper.crop(data, properties, seg)
@staticmethod
def crop(data, properties, seg=None):
shape_before = data.shape
data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=-1)
shape_after = data.shape
print("before crop:", shape_before, "after crop:", shape_after, "spacing:",
np.array(properties["original_spacing"]), "n")
# pkl文件中存储如下信息:bb,类个数(包含背景),crop后的size
properties["crop_bbox"] = bbox
properties['classes'] = np.unique(seg)
seg[seg < -1] = 0
properties["size_after_cropping"] = data[0].shape
return data, seg, properties
下面陈列下主要的代码,而不是复制粘贴整个函数。
# 数据的shape是(C, X, Y, Z) 或者 (C, X, Y)
nonzero_mask = np.zeros(data.shape[1:], dtype=bool)
for c in range(data.shape[0]):
this_mask = data[c] != 0 不等于0的都认为是背景
nonzero_mask = nonzero_mask | this_mask
nonzero_mask = binary_fill_holes(nonzero_mask) # scipy.ndimage
# 获取bounding box的函数。这个在日常会经常用到
def get_bbox_from_mask(mask, outside_value=0):
mask_voxel_coords = np.where(mask != outside_value)
minzidx = int(np.min(mask_voxel_coords[0]))
maxzidx = int(np.max(mask_voxel_coords[0])) + 1 # +1如果超出图像大小怎么办?
minxidx = int(np.min(mask_voxel_coords[1]))
maxxidx = int(np.max(mask_voxel_coords[1])) + 1
minyidx = int(np.min(mask_voxel_coords[2]))
maxyidx = int(np.max(mask_voxel_coords[2])) + 1
return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]]
# 根据bbox提取ROI,这个也经常会用到,为什么用slice做索引,而不直接bbox[0][0]:bbox[0][1]...
def crop_to_bbox(image, bbox):
assert len(image.shape) == 3, "only supports 3d images"
resizer = (slice(bbox[0][0], bbox[0][1]), slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1]))
return image[resizer]
最终每个病例的pkl文件信息如下
classes中的-1表示非0背景
为什么这么实现?
以下内容都很主观,请务必带着批判和质疑审视,也恳请知道答案的高手不吝解惑。
- Q: 为什么把crop定义成一个类?
A: 从上面的介绍可以看出crop这个过程虽然目的很简单,但涉及到很多其他需求,比如数据的读取,存储,bounding box的记录,定义成类使得不同小功能之间更加模块化。
- Q: run_cropping中为什么把参数传给私有方法_load_crop_save_star,而不是直接传给load_crop_save方法?
A: 仅从功能上来说,实际上没有特殊意义;猜想应该是作者为了后续的扩展性及代码的调用便捷考虑,后续可能会扩展 load_crop_save_star的操作,项目较大,代码抽象度高的时候,这么做的优势就显出来。
- 为什么把crop_from_list_of_files和crop定义为静态方法?
A: 静态方法的文档。方便用类方法直接调用,testing的时候也会用到这个方法。
- Q: ImageCropper类中最后两个方法load_properties,save_properties在这里没用到吗?
A: 目前看来是的。这两个函数功能很清楚,后续如果看到有用了再补充。
几个小问题,可能需要调试下代码才知道
- 数据的shape是(C, X, Y, Z),C是什么?模态数量
- get_bbox_from_mask,max坐标+1如果超出图像大小怎么办?
- 根据bbox提取ROI,为什么用slice做索引,而不直接bbox[0][0]:bbox[0][1]...
感谢评论区安兴乐的指导。