图像分类中如何使用 CutMix

drawing

使用数据集以及运行源代码:
PyTorch 图像分类与图像分割中使用 CutMix.

概述

CutMix 是效果比较好的一类数据增强,常混迹于各大视觉比赛。

那么有小伙伴问了, CutMix 是什么呢?混砍?这我知道,我最近看的扫黑风暴里就有。。。也太暴力了。

drawing

我说的是《CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features》这篇论文!

CutMix增强策略:

  • 在训练图像中剪切和粘贴补丁,其中真实标签的混合与补丁的面积成正比:若补丁取自图像 $B$,其面积占整幅图像的比例为 $1-\lambda$,则混合后的标签为 $\tilde{y} = \lambda\, y_A + (1-\lambda)\, y_B$。
drawing drawing

这次借2021年江苏大数据开发与应用大赛(华录杯)的医疗赛道数据集来做一个对比实验,也是我初赛的方案,分享给大家,也请大佬指正错误。

这次比赛的目标是:

  • 针对胃癌病理切片,对发生癌症病变的区域进行像素级预测并对癌症类别进行分类。
  • 标签 0、1、2 分别对应正常、管状腺癌、粘液腺癌
  • 同时还需要对病灶区域进行分割,如下图:
    ![病灶区域分割示例](attachment:12892f7a-0c03-410e-a4d3-aea44d2e798d.png)
# CutMix patch sampling: pick a random box covering roughly (1 - lam) of the image.
def rand_bbox(size, lam):
    """Sample a random CutMix box.

    Args:
        size: shape of the target. A 4-sequence is read as ``(N, C, d0, d1)``
            and a 3-sequence as ``(d0, d1, C)`` (e.g. ``image.shape`` of an
            HWC image).
        lam: mixing ratio in [0, 1]; each side of the box is
            ``sqrt(1 - lam)`` times the corresponding spatial dimension.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. ``bbx*`` index the first spatial axis
        (``d0``) and ``bby*`` the second (``d1``), so the patch is
        ``img[bbx1:bbx2, bby1:bby2]``. The box centre is uniform over the
        image and the corners are clipped to it, so near the borders the box
        can be smaller than requested; callers should recompute lam from it.

    Raises:
        ValueError: if ``size`` has neither 3 nor 4 entries.
    """
    if len(size) == 4:
        dim_x, dim_y = size[2], size[3]
    elif len(size) == 3:
        dim_x, dim_y = size[0], size[1]
    else:
        raise ValueError(
            f"rand_bbox expects a 3- or 4-dimensional size, got {tuple(size)!r}"
        )

    cut_rat = np.sqrt(1. - lam)
    # np.int was only an alias of the builtin int and was removed in NumPy 1.24.
    cut_w = int(dim_x * cut_rat)
    cut_h = int(dim_y * cut_rat)

    # uniform box centre
    cx = np.random.randint(dim_x)
    cy = np.random.randint(dim_y)

    bbx1 = np.clip(cx - cut_w // 2, 0, dim_x)
    bby1 = np.clip(cy - cut_h // 2, 0, dim_y)
    bbx2 = np.clip(cx + cut_w // 2, 0, dim_x)
    bby2 = np.clip(cy + cut_h // 2, 0, dim_y)

    return bbx1, bby1, bbx2, bby2

# Basic data augmentation (random horizontal flip / vertical flip / transpose)
# for training; every other phase only resizes and normalises.
def make_transforms(phase, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the albumentations pipeline for one data split.

    Args:
        phase: ``'train'`` prepends a random geometric step; any other value
            gives the deterministic evaluation pipeline.
        mean, std: per-channel statistics handed to ``albu.Normalize``.

    Returns:
        An ``albu.Compose`` that resizes to ``image_size`` x ``image_size``
        (module-level setting), normalises, and converts to a tensor.
    """
    tail = [
        albu.Resize(image_size, image_size, p=1),
        albu.Normalize(mean=mean, std=std, p=1),
        ToTensorV2(),
    ]
    if phase != 'train':
        return albu.Compose(tail)

    # At most one of these three geometric ops is applied per sample (OneOf).
    random_geometry = albu.OneOf([
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
        albu.Transpose(p=0.5),
    ])
    return albu.Compose([random_geometry, *tail])

定义 PyTorch 的 Dataset

class JSHDataset(Dataset):
    """Gastric-cancer slide dataset yielding ``(image, label, mask)``.

    Each row of ``df`` provides ``image_name``, ``image_path``, ``mask_path``
    and an integer ``label`` (0/1/2). Images and masks are resized to
    1024x1024 before ``transforms`` is applied. In training mode, CutMix is
    applied with probability ``cutmix_prob`` percent: a random box from
    another sample is pasted into both image and mask, and the one-hot labels
    are mixed in proportion to the pasted area.
    """

    _LOAD_SIZE = (1024, 1024)  # dsize passed to cv2.resize
    _NUM_CLASSES = 3

    def __init__(self, df, transforms, train=False, cutmix_prob=20):
        """
        Args:
            df: table of samples, indexed positionally with ``.iloc``.
            transforms: callable taking ``image=`` / ``mask=`` keywords and
                returning a dict with ``'image'`` and ``'mask'`` entries.
            train: enable CutMix (use for the training split only).
            cutmix_prob: CutMix probability in percent; 0 disables it and 100
                applies it to every sample. The default (20) matches the
                previously hard-coded value.
        """
        self.df = df
        self.transforms = transforms
        self.train = train
        self.cutmix_prob = cutmix_prob

    def _load(self, row):
        """Read one sample's RGB image and [0, 1]-scaled mask, both resized."""
        fn = row.image_name
        image_file = os.path.join(row['image_path'], fn)
        mask_file = os.path.join(row['mask_path'], fn)

        image = cv2.imread(image_file)
        if image is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError(f"cannot read image: {image_file}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, self._LOAD_SIZE, interpolation=cv2.INTER_LINEAR)

        masks = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        if masks is None:
            raise FileNotFoundError(f"cannot read mask: {mask_file}")
        masks = cv2.resize(masks / 255, self._LOAD_SIZE, interpolation=cv2.INTER_LINEAR)
        return image, masks

    def _one_hot(self, label):
        """Return a float one-hot vector of length ``_NUM_CLASSES``."""
        vec = torch.zeros(self._NUM_CLASSES)
        vec[label] = 1
        return vec

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image, masks = self._load(row)
        label = self._one_hot(row.label)

        # ------------------------------  CutMix  ------------------------------
        if self.train and random.randint(0, 99) < self.cutmix_prob:
            rand_row = self.df.iloc[random.randint(0, len(self.df) - 1)]
            rand_image, rand_masks = self._load(rand_row)

            lam = np.random.beta(1, 1)
            bbx1, bby1, bbx2, bby2 = rand_bbox(image.shape, lam)

            # rand_bbox's "x" coordinates index axis 0 (rows), "y" index axis 1.
            image[bbx1:bbx2, bby1:bby2, :] = rand_image[bbx1:bbx2, bby1:bby2, :]
            masks[bbx1:bbx2, bby1:bby2] = rand_masks[bbx1:bbx2, bby1:bby2]

            # Recompute lam from the (possibly clipped) box actually pasted.
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (image.shape[0] * image.shape[1]))
            label = label * lam + self._one_hot(rand_row.label) * (1. - lam)
        # ------------------------------  CutMix  ------------------------------

        # Apply the augmentation / normalisation pipeline.
        augmented = self.transforms(image=image, mask=masks)
        img, mask = augmented['image'], augmented['mask']
        return img, label, mask.unsqueeze(0)

    def __len__(self):
        return len(self.df)

使用 PyTorch 的 Dataloader 创建数据的生成器

trainset = JSHDataset(train_df, make_transforms('train'), train=True)
valset = JSHDataset(val_df, make_transforms('val'))

# Loader settings shared by the training and validation splits.
_loader_kwargs = dict(batch_size=batch_size, num_workers=8, pin_memory=True)

# shuffle=True is the simplest way to randomise order; for an imbalanced
# dataset a sampler can be used instead.
train_loader = DataLoader(trainset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(valset, **_loader_kwargs)

可视化

drawing

下面分别对原始数据和进行了 CutMix 操作的数据进行可视化。注意两段可视化代码完全相同,而 `trainset` 默认以 20% 的概率做 CutMix;要查看未经 CutMix 的原始数据,需先将 `__getitem__` 中的 `prob` 设为 0。

# Show 5 random training samples: images on the top row, masks below.
# The dataset output is normalised with the ImageNet mean/std (the
# make_transforms defaults), so undo that before display; clipping the
# normalised tensor straight to [0, 1] shows distorted colours.
# NOTE: trainset applies CutMix with the probability configured in JSHDataset,
# so disable it there to see only un-mixed samples.
_VIS_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_VIS_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

random_list = [random.randint(0, len(trainset)-1) for i in range(5)]
f, ax = plt.subplots(2, 5, figsize=(14,4))
for i, sample_idx in enumerate(random_list):
    img, _, mask = trainset[sample_idx]
    img = img * _VIS_STD + _VIS_MEAN
    ax[0][i].imshow(torch.clip(img.permute(1, 2, 0), 0, 1))
    ax[1][i].imshow(torch.clip(mask.permute(1, 2, 0), 0, 1))

在这里插入图片描述

# Visualise training samples with CutMix enabled; mixed samples show the
# pasted rectangular "patch" in both the image and the mask.
# Images are de-normalised (ImageNet mean/std, the make_transforms defaults)
# before display so the colours are correct.
_VIS_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_VIS_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

random_list = [random.randint(0, len(trainset)-1) for i in range(5)]
f, ax = plt.subplots(2, 5, figsize=(14,4))
for i, sample_idx in enumerate(random_list):
    img, _, mask = trainset[sample_idx]
    img = img * _VIS_STD + _VIS_MEAN
    ax[0][i].imshow(torch.clip(img.permute(1, 2, 0), 0, 1))
    ax[1][i].imshow(torch.clip(mask.permute(1, 2, 0), 0, 1))

在这里插入图片描述

模型配置

ENCODER_WEIGHTS = 'imagenet'  # pretrained weights for the smp encoder
ACTIVATION = None             # no output activation: the losses take raw logits
DEVICE = 'cuda'               # not referenced below; tensors are moved with .cuda()
n_class = 1                   # one segmentation output channel (matches the 1-channel mask)

创建分类的 head

# Auxiliary classification head for smp: pooled encoder features -> 3 classes
# (normal / tubular adenocarcinoma / mucinous adenocarcinoma), raw logits out.
aux_params = {
    'pooling': 'avg',    # 'avg' or 'max'
    'activation': None,  # no activation; BCEWithLogitsLoss consumes logits
    'classes': 3,        # number of output labels
}

创建分割的 head 同时载入模型预训练权重

# FPN segmentation network with the auxiliary classification head; the encoder
# is initialised from ENCODER_WEIGHTS. Calling it yields (mask_logits, class_logits).
model = smp.FPN(
    ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=n_class,
    activation=ACTIVATION,
    aux_params=aux_params,
)
model = model.cuda()

学习率、优化器

learning_rate = 3e-5

# AdamW over all model parameters with explicit beta / epsilon settings and
# AMSGrad disabled.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    eps=1e-08,
    amsgrad=False,
)

定义前向传播函数(同时计算分类与分割的 loss)

# Both heads emit raw logits, so BCE-with-logits serves the (possibly
# CutMix-mixed, soft) class labels as well as the mask.
bce_loss = torch.nn.BCEWithLogitsLoss().cuda()


def forward_fn(trainer, data):
    """Run one batch through the model and compute the joint loss.

    Args:
        trainer: supplied by the training loop (this function is passed to
            ``minetorch.Miner`` as ``forward``); unused here.
        data: ``(images, labels, masks)`` batch from the DataLoader.

    Returns:
        ``((label_logits, mask_logits), loss)`` with
        ``loss = BCE(labels) + 3 * BCE(masks)``, i.e. the segmentation term is
        weighted three times the classification term.
    """
    images, labels, masks = data
    # The loaders use pin_memory=True; non_blocking=True lets these host->GPU
    # copies be asynchronous instead of synchronising on each one.
    images = images.cuda(non_blocking=True)
    labels = labels.cuda(non_blocking=True)
    masks = masks.cuda(non_blocking=True)

    mask, label = model(images)

    loss = bce_loss(label, labels) + 3 * bce_loss(mask, masks)

    return (label, mask), loss

训练器的配置

# Trainer configuration. CODE and ALCHEMISTIC_DIRECTORY may be overridden via
# environment variables; otherwise they are derived from the fold/encoder/size.
miner = minetorch.Miner(
    # experiment identity and output location
    code=os.getenv('CODE', f'fold-{fold}'),
    alchemistic_directory=os.getenv('ALCHEMISTIC_DIRECTORY', f'{Folder}{ENCODER}-cutmix-{image_size}'),
    # model and optimisation
    model=model,
    optimizer=optimizer,
    forward=forward_fn,
    loss_func=None,  # the loss is computed and returned by forward_fn
    # data
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    # run settings
    max_epochs=100,
    amp=True,
    in_notebook=True,
    trival=False,
    resume=False,
    plugins=[MultiClassesSegmentationMetricWithLogic()],
)

drawing

把两个实验的 Loss 画到一张图上来对比一下。

可以清楚地看到,在训练参数相同的情况下,使用 CutMix 后的 Validation Loss 更低。

以上就是 CutMix 在图像分类以及分割当中的应用,请小伙伴们参考,如有错误也请大佬指正。
在这里插入图片描述
image

  • 4
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
好的,下面是基于PyTorch和CutMix进行数据增强和绘制GradCAM热力图的全部代码:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from torchvision import models, transforms


class CutMixTransform:
    """CutMix augmentation for PIL images.

    Reference: https://arxiv.org/pdf/1905.04899.pdf

    With probability ``prob``, a box covering about ``1 - lam`` of the image
    (``lam ~ Beta(alpha, alpha)``) is cut from ``mix_image`` and pasted onto
    ``image`` at the same position. The returned target is
    ``(target, mix_target, lam)``; train with
    ``lam * loss(out, target) + (1 - lam) * loss(out, mix_target)``.
    """

    def __init__(self, alpha=1.0, prob=0.5):
        self.alpha = alpha  # Beta(alpha, alpha) parameter -- NOT the mixing ratio itself
        self.prob = prob    # probability of applying CutMix to a sample

    def __call__(self, image, target, mix_image=None, mix_target=None):
        if mix_image is None:
            # No partner image given: mix with a mirrored copy so the demo
            # works with a single picture.
            mix_image, mix_target = ImageOps.mirror(image), target
        if np.random.rand() >= self.prob:
            return image, (target, mix_target, 1.0)

        w, h = image.size
        mix_image = mix_image.convert(image.mode).resize((w, h))
        lam = np.random.beta(self.alpha, self.alpha)
        cut_rat = np.sqrt(1.0 - lam)
        cut_w, cut_h = int(w * cut_rat), int(h * cut_rat)
        cx, cy = np.random.randint(w), np.random.randint(h)
        x1 = int(np.clip(cx - cut_w // 2, 0, w))
        y1 = int(np.clip(cy - cut_h // 2, 0, h))
        x2 = int(np.clip(cx + cut_w // 2, 0, w))
        y2 = int(np.clip(cy + cut_h // 2, 0, h))

        image = image.copy()
        if x2 > x1 and y2 > y1:
            box = (x1, y1, x2, y2)
            image.paste(mix_image.crop(box), box)
        # Recompute lam from the clipped box so the label weights match the pixels.
        lam = 1.0 - (x2 - x1) * (y2 - y1) / (w * h)
        return image, (target, mix_target, lam)


class Model(nn.Module):
    """ResNet-50 (ImageNet weights) with a new ``num_classes``-way head.

    ``forward`` returns ``(feature_maps, logits)``; ``feature_maps`` is the
    ``layer4`` output that GradCAM needs. The new ``fc`` layer is randomly
    initialised, so its predictions are meaningless until it is fine-tuned.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        r = self.resnet
        x = r.maxpool(r.relu(r.bn1(r.conv1(x))))
        feature_maps = r.layer4(r.layer3(r.layer2(r.layer1(x))))
        logits = r.fc(torch.flatten(r.avgpool(feature_maps), 1))
        return feature_maps, logits


PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def inference(model, image):
    """Classify one image given as a file path or a PIL image.

    Returns ``(input_tensor, probs, label)``; ``input_tensor`` is already on
    the model's device.
    """
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    device = next(model.parameters()).device
    input_tensor = PREPROCESS(image.convert("RGB")).unsqueeze(0).to(device)
    model.eval()  # use BatchNorm running statistics, not batch-of-one stats
    with torch.no_grad():
        _, logits = model(input_tensor)
    probs = F.softmax(logits, dim=1)
    return input_tensor, probs, probs.argmax(dim=1)


def gradcam(model, input_tensor, class_idx):
    """Return the GradCAM heatmap of ``class_idx`` for a single image.

    The heatmap is upsampled to the input resolution and scaled to [0, 1].
    """
    model.eval()
    input_tensor = input_tensor.to(next(model.parameters()).device)
    feature_maps, logits = model(input_tensor)
    # GradCAM weights come from gradients w.r.t. the activations, not the conv weights.
    grads, = torch.autograd.grad(logits[0, class_idx], feature_maps)
    weights = grads.mean(dim=(2, 3), keepdim=True)                     # (1, C, 1, 1)
    cam = F.relu((weights * feature_maps).sum(dim=1, keepdim=True))    # (1, 1, h, w)
    cam = F.interpolate(cam, size=input_tensor.shape[-2:], mode="bilinear", align_corners=False)[0, 0]
    return (cam / cam.max().clamp(min=1e-8)).detach()


CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model
model = Model(num_classes=len(CLASS_NAMES)).to(device)

# Load image and perform inference
image_path = "cat.jpg"
image = Image.open(image_path).convert("RGB")
input_tensor, output_probs, output_label = inference(model, image)

# Randomly pick a class according to the predicted probabilities
class_idx = torch.multinomial(output_probs[0], 1).item()
class_name = CLASS_NAMES[class_idx]

# Perform CutMix augmentation (prob=1.0 so the patch is always visible) and inference
cutmix_transforms = CutMixTransform(alpha=1.0, prob=1.0)
image_cutmix, target_cutmix = cutmix_transforms(image, class_idx)
input_tensor_cutmix, output_probs_cutmix, output_label_cutmix = inference(model, image_cutmix)

# GradCAM for the chosen class on both the original and the CutMix image
rows = [
    ("Original", image, output_label, gradcam(model, input_tensor, class_idx)),
    ("CutMix", image_cutmix, output_label_cutmix, gradcam(model, input_tensor_cutmix, class_idx)),
]

fig, ax = plt.subplots(2, 2, figsize=(10, 10))
for r, (title, img, pred, heatmap) in enumerate(rows):
    ax[r][0].imshow(img)
    ax[r][0].set_title(f"{title}\nPred: {CLASS_NAMES[pred.item()]}")
    ax[r][1].imshow(heatmap.cpu().numpy(), cmap="jet")
    ax[r][1].set_title(f"GradCAM for {class_name}")
plt.show()
```

在这个示例中,我们首先定义了一个`CutMixTransform`类来实现CutMix数据增强,然后定义了一个预训练的ResNet50模型用于图像分类。我们对一张测试图像进行预测,并从预测结果随机选择一个类别,然后利用`CutMixTransform`对原始图像进行数据增强,并在增强后的图像上进行预测。然后,我们将增强后的图像和原始图像都用于计算GradCAM热力图,并展示出来。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dave 扫地工

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值