目录
背景
相比采用表格化数据定量分析,可视化是分割任务定性分析的主要手段。
在多类别图像分割任务中,往往会涉及两种及以上的颜色,一种常用的方法就是使用RGB色彩填充分割target。
最简单的实现方法就是两次for循环遍历图像,逐个像素进行填充,但这种方式耗时较长,采用np.where()等numpy高级函数可以有效加速,缩短程序运行时间。
方法
以腹部器官数据集Synapse为例,其中包含除去背景一类的8类器官(aorta、gallbladder、liver……),这种情况下,可视化就需要八种颜色来分别表示8类器官,所以先定义对应的颜色:
blue = [30,144,255] # aorta
green = [0,255,0] # gallbladder
red = [255,0,0] # left kidney
cyan = [0,255,255] # right kidney
pink = [255,0,255] # liver
yellow = [255,255,0] # pancreas
purple = [128,0,255] # spleen
orange = [255,128,0] # stomach
这里选取了几种对比较为鲜明的颜色。
接下来需要处理一下原始图像以及预测结果,一般来说作为模型的输入,原始图像一般会经过归一化操作,所以乘以255恢复原始像素值,在转换为uint8格式以使用opencv处理。此外,原始图像若是灰度图还需转换为3通道的RGB格式。分割结果默认是非one-hot的一维格式(通过torch.argmax转换得到),所以也需要转换为RGB三维格式:
original_img = original_img * 255.0
original_img = original_img.astype(np.uint8)
original_img = cv2.cvtColor(original_img,cv2.COLOR_GRAY2BGR)
pred = cv2.cvtColor(pred,cv2.COLOR_GRAY2BGR)
接下来是核心部分,以其中一类为例:
original_img = np.where(pred==1, np.full_like(original_img, blue), original_img)
np.where(condition, x, y) 函数定义如下:满足条件(condition),输出x,不满足输出y,具体可参考点击跳转。此处的条件是判断预测结果的某一像素点是否为“1”类别,若是则使用定义好的“blue”RGB值填充(np.full_like的作用是创建一个等于原图大小的全蓝色图像,多余的蓝色区域可以被pred掩膜过滤掉),若不是则输出原始图像的像素值。分别对8类处理完后保存即可,最终可视化结果如图(左:原图,右:分割结果):
完整代码如下,输入分别为原图,预测结果,可视化结果保存路径:
def vis_save(original_img, pred, save_path):
blue = [30,144,255] # aorta
green = [0,255,0] # gallbladder
red = [255,0,0] # left kidney
cyan = [0,255,255] # right kidney
pink = [255,0,255] # liver
yellow = [255,255,0] # pancreas
purple = [128,0,255] # spleen
orange = [255,128,0] # stomach
original_img = original_img * 255.0
original_img = original_img.astype(np.uint8)
original_img = cv2.cvtColor(original_img,cv2.COLOR_GRAY2BGR)
pred = cv2.cvtColor(pred,cv2.COLOR_GRAY2BGR)
original_img = np.where(pred==1, np.full_like(original_img, blue ), original_img)
original_img = np.where(pred==2, np.full_like(original_img, green ), original_img)
original_img = np.where(pred==3, np.full_like(original_img, red ), original_img)
original_img = np.where(pred==4, np.full_like(original_img, cyan ), original_img)
original_img = np.where(pred==5, np.full_like(original_img, pink ), original_img)
original_img = np.where(pred==6, np.full_like(original_img, yellow), original_img)
original_img = np.where(pred==7, np.full_like(original_img, purple), original_img)
original_img = np.where(pred==8, np.full_like(original_img, orange), original_img)
original_img = cv2.cvtColor(original_img,cv2.COLOR_BGR2RGB)
cv2.imwrite(save_path, original_img)
实测结果表明,使用np.where()函数实现可比双for循环遍历快6倍左右。