即插即用的自动数据增强方法MedAugment,可以将自动数据增强引入到医学图像分析领域,显著降低了数据增强对经验的依赖。

今天主要向大家介绍一种名为 MedAugment 的自动数据增强方法,旨在将自动数据增强技术引入医学图像分析领域。该方法通过将增强空间划分为两种:

  • 像素增强空间
  • 空间增强空间

从而解决了自然图像和医学图像之间的差异。论文提出了一种新颖的操作采样策略,用于从增强空间中采样数据增强操作。为了证明 MedAugment 的性能和泛化性能,作者在四个分类数据集和三个分割数据集上进行了广泛的实验,并表明 MedAugment 优于大多数最先进的数据增强方法,例如比较主流的 AutoAugment 和 RandAugment 等自动数据增强方法。

虽然说是用于医学图像领域,但完全可以稍作修改照搬到自然图像即可。

动机

为什么不直接使用现在的自动数据增强方法?作者认为,这些方法最初设计用于自然图像,并不能直接应用于医学图像分析领域。此外,大部分自动数据增强方法最初也是针对图像分类任务设计的,而在医学图像分析领域,图像分割是一个核心任务。因此,目前缺乏一种适用于医学图像分析的通用且强大的自动数据增强方法,本文提出了 MedAugment 来填补这一空白。

方法

MedAugment_医学图像

整体框架介绍

MedAugment 的原理其实很简单,其框架实现如上图1所示。在该方法中,作者设计了两个增强空间 Ap 和 As,分别包含六个和八个数据增强操作。这样我们便有14种数据增强操作(可根据自己的数据集特点自由发挥)。

增强空间

为了适应医学图像分析领域的特点,作者对具体的数据增强方法进行了细心的设置。首先,从常见的数据增强操作开始,然后根据 MIA 领域的特点进行筛选,排除了不适合医学图像的操作,如反转、均衡化和反转等操作,这些操作可能破坏医学图像中的细节和特征。接下来,我们将数据增强操作分为像素级和空间级操作,并构建了两个增强空间,即像素增强空间Ap和空间增强空间As。Ap和As分别包括与像素和空间相关的数据增强操作。需要注意的是,Ap中的数据增强操作不适用于掩膜。【源码是基于 Albumentations 软件包实现了这些数据增强操作,大家也可以针对自身任务特点去设计】

采样策略

由于医学图像对亮度等属性非常敏感,并且我们观察到Ap中连续的操作可能会导致不真实的输出医学图像,因此本文设计了一种新颖的操作采样策略,用于从Ap和As中采样操作。

具体而言,我们随机采样每个分支的M个数据增强操作,其中从Ap中采样的操作数量不超过一个。我们经过权衡决定了M的取值范围。对于连续的数据增强操作,我们需要谨慎考虑连续操作的数量。使用更多的连续操作可能进一步提高模型的泛化能力,但过多的连续操作可能生成与原始图像差距较大的图像。

因此,我们确定M的上限为3。由于本文方法并行使用数据增强操作,将M设置为1没有意义,因为这会退化为单个操作而没有组合效果。基于这些考虑,作者设计了M = {2, 3}。给定M = {2, 3},我们生成了四种从Ap和As中采样的组合方式,分别是1 + 2、0 + 3、1 + 1和0 + 2。这个采样组合的数量解释了为什么分支数N = 4。

MedAugment_医学图像_02

为了更好的扩展,我们也可以将 MedAugment 中的 N 设计为可扩展到其他值,并使用替换采样的方式进行采样。同时,单独的分支也可以被屏蔽。当N = 1并屏蔽单独的分支时,MedAugment 可以执行一对一的数据增强。通过图2的比较可以看出,我们的MedAugment更适合医学图像,而现有的方法可能会生成不真实的增强图像。在最差的情况下,一些增强图像被认为是无意义的,因为存在过多的“噪音”,或者几乎没有有用的信息。这些增强图像可能在深度学习模型上能够被正确识别,但从医学的角度来看是没有意义的。

超参数映射

MedAugment_医学图像_03

源码实现 
import albumentations as A  
import torch  
import math  
import random  
import os  
import cv2  
import shutil  
import numpy as np  
import argparse  
from torchvision import transforms  
from PIL import Image  
  
  
def make_odd(num):  
    num = math.ceil(num)  
    if num % 2 == 0:  
        num += 1  
    return num  
  
  
def med_augment(data_path, name, level, number_branch, mask_i=False, shield=False):  
    if mask_i:  
        image_path = f"{data_path}{name}"  
        mask_path = f"{image_path}_mask"  
        output_path = f"{os.path.dirname(os.path.dirname(data_path))}/medaugment/{name}/"  
        out_mask = f"{os.path.dirname(os.path.dirname(data_path))}/medaugment/{name}_mask/"  
    else:  
        image_path = data_path + name  
        output_path = f"{os.path.dirname(os.path.dirname(os.path.dirname(data_path)))}/medaugment/training/{name}/"  
  
    transform = A.Compose([  
        A.ColorJitter(brightness=0.04 * level, contrast=0, saturation=0, hue=0, p=0.2 * level),  
        A.ColorJitter(brightness=0, contrast=0.04 * level, saturation=0, hue=0, p=0.2 * level),  
        A.Posterize(num_bits=math.floor(8 - 0.8 * level), p=0.2 * level),  
        A.Sharpen(alpha=(0.04 * level, 0.1 * level), lightness=(1, 1), p=0.2 * level),  
        A.GaussianBlur(blur_limit=(3, make_odd(3 + 0.8 * level)), p=0.2 * level),  
        A.GaussNoise(var_limit=(2 * level, 10 * level), mean=0, per_channel=True, p=0.2 * level),  
        A.Rotate(limit=4 * level, interpolation=1, border_mode=0, value=0, mask_value=None, rotate_method='largest_box',  
                 crop_border=False, p=0.2 * level),  
        A.HorizontalFlip(p=0.2 * level),  
        A.VerticalFlip(p=0.2 * level),  
        A.Affine(scale=(1 - 0.04 * level, 1 + 0.04 * level), translate_percent=None, translate_px=None, rotate=None,  
                 shear=None, interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,  
                 keep_ratio=True, p=0.2 * level),  
        A.Affine(scale=None, translate_percent=None, translate_px=None, rotate=None,  
                 shear={'x': (0, 2 * level), 'y': (0, 0)}  
                 , interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,  
                 keep_ratio=True, p=0.2 * level),  # x  
        A.Affine(scale=None, translate_percent=None, translate_px=None, rotate=None,  
                 shear={'x': (0, 0), 'y': (0, 2 * level)}  
                 , interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,  
                 keep_ratio=True, p=0.2 * level),  
        A.Affine(scale=None, translate_percent={'x': (0, 0.02 * level), 'y': (0, 0)}, translate_px=None, rotate=None,  
                 shear=None, interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,  
                 keep_ratio=True, p=0.2 * level),  
        A.Affine(scale=None, translate_percent={'x': (0, 0), 'y': (0, 0.02 * level)}, translate_px=None, rotate=None,  
                 shear=None, interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,  
                 keep_ratio=True, p=0.2 * level)  
    ])  
  
    for j, file_name in enumerate(os.listdir(image_path)):  
        if file_name.endswith(".png") or file_name.endswith(".jpg"):  
            file_path = os.path.join(image_path, file_name)  
            file_n, file_s = file_name.split(".")[0], file_name.split(".")[1]  
            image = cv2.imread(file_path)  
            if mask_i: mask = cv2.imread(f"{mask_path}/{file_n}_mask.{file_s}")  
            strategy = [(1, 2), (0, 3), (0, 2), (1, 1)]  
            for i in range(number_branch):  
                if number_branch != 4:  
                    employ = random.choice(strategy)  
                else:  
                    index = random.randrange(len(strategy))  
                    employ = strategy.pop(index)  
                level, shape = random.sample(transform[:6], employ[0]), random.sample(transform[6:], employ[1])  
                img_transform = A.Compose([*level, *shape])  
                random.shuffle(img_transform.transforms)  
                if not os.path.exists(output_path): os.makedirs(output_path)  
                if mask_i:  
                    transformed = img_transform(image=image, mask=mask)  
                    transformed_image, transformed_mask = transformed['image'], transformed['mask']  
                    cv2.imwrite(f"{output_path}/{file_n}_{i+1}.{file_s}", transformed_image)  
                    cv2.imwrite(f"{out_mask}/{file_n}_{i+1}_mask.{file_s}", transformed_mask)  
                else:  
                    transformed = img_transform(image=image)  
                    transformed_image = transformed['image']  
                    cv2.imwrite(f"{output_path}/{file_n}_{i+1}.{file_s}", transformed_image)  
                if not shield:  
                    cv2.imwrite(f"{output_path}/{file_n}_{number_branch+1}.{file_s}", image)  
                    if mask_i: cv2.imwrite(f"{out_mask}/{file_n}_{number_branch+1}_mask.{file_s}", mask)  
  
  
def generate_datasets(train_type, dataset, seed, level, number_branch):  
  
    torch.manual_seed(seed)  
    random.seed(seed)  
    np.random.seed(seed)  
    torch.cuda.manual_seed(seed)  
  
    if train_type == "classification":  
        print('Executing data augmentation for image classification...')  
        data_path = f"./datasets/classification/{dataset}/baseline/training/"  
        folder_path = f"./datasets/classification/{dataset}/"  
        n = len([name for name in os.listdir(f"{folder_path}/baseline/training") if  
                 os.path.isdir(os.path.join(f"{folder_path}/baseline/training", name))])  
  
        for folder in ["medaugment"]:  
            shutil.copytree(f"{folder_path}baseline", f"{folder_path}{folder}",  
                            ignore=shutil.ignore_patterns("training"))  
            training_folder_path = f"{folder_path}{folder}/training"  
            os.makedirs(training_folder_path)  
            for i in range(n):  
                os.makedirs(f"{training_folder_path}/n{i}")  
  
        for i in range(n):  
            name = f"n{i}"  
            med_augment(data_path, name, level, number_branch)  
    else:  
        print('Executing data augmentation for image segmentation...')  
        data_path = f"./datasets/segmentation/{dataset}/baseline/"  
        folder_path = f"./datasets/segmentation/{dataset}/"  
  
        for folder in ["medaugment"]:  
            shutil.copytree(f"{folder_path}baseline", f"{folder_path}{folder}",  
                            ignore=shutil.ignore_patterns("training", "training_mask"))  
            os.makedirs(f"{folder_path}{folder}/training")  
            os.makedirs(f"{folder_path}{folder}/training_mask")  
  
        folder_list = ["training"]  
        for i in range(len(folder_list)):  
            name = folder_list[i]  
            med_augment(data_path, name, level, number_branch, mask_i=True)  
  
  
def main():  
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)  
    group = parser.add_argument_group()  
    group.add_argument('--dataset', required=True)  
    group.add_argument('--train_type', choices=['classification', 'segmentation'], default='classification')  
    group.add_argument('--level', help='Augmentation level', default=5, type=int, metavar='INT')  
    group.add_argument('--number_branch', help='Number of branch', default=4, type=int, metavar='INT')  
    group.add_argument('--seed', help='Seed', default=8, type=int, metavar='INT')  
    args = parser.parse_args()  
    generate_datasets(**vars(args))  
  
  
if __name__ == '__main__':  
    main()
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.
  • 87.
  • 88.
  • 89.
  • 90.
  • 91.
  • 92.
  • 93.
  • 94.
  • 95.
  • 96.
  • 97.
  • 98.
  • 99.
  • 100.
  • 101.
  • 102.
  • 103.
  • 104.
  • 105.
  • 106.
  • 107.
  • 108.
  • 109.
  • 110.
  • 111.
  • 112.
  • 113.
  • 114.
  • 115.
  • 116.
  • 117.
  • 118.
  • 119.
  • 120.
  • 121.
  • 122.
  • 123.
  • 124.
  • 125.
  • 126.
  • 127.
  • 128.
  • 129.
  • 130.
  • 131.
  • 132.
  • 133.
  • 134.
  • 135.
  • 136.
  • 137.
  • 138.
  • 139.
  • 140.
  • 141.
  • 142.
  • 143.
  • 144.
  • 145.
  • 146.
  • 147.
总结

今天为大家介绍了一种名为 MedAugment 的即插即用的自动数据增强方法,可以将自动数据增强引入到医学图像分析领域,显著降低了数据增强对经验的依赖。通过设计增强空间和采样策略,MedAugment 解决了自然图像和医学图像之间的差异。

然而,MedAugment 和其他最先进的方法在平衡不同评估指标方面仍存在一些问题,例如敏感性等指标可能较低。因此,可以进一步研究如何平衡不同指标。例如,可以引入和评估多个超参数,在训练过程中通过超参数更新来平衡不同指标。另外,小目标尺寸的问题需要进一步研究。这导致如何强调待分割的目标成为一个问题。进一步的研究可以根据目标的大小应用不同类型和级别的数据增强。例如,与较大的目标相比,具有较小目标的图像在增强过程中更有可能被放大,同时放大系数也可以更大。