【数据增强】CutMix

方法来源

CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
在这里插入图片描述

简单来说cutmix相当于cutout+mixup的结合,可以应用于各种任务中。

参考代码

import numpy as np
import torch

def rand_bbox(size, lam):
    """生成 CutMix 的随机边界框。"""
    W = size[2]
    H = size[3]
    cut_w = int(W * np.sqrt(1 - lam))
    cut_h = int(H * np.sqrt(1 - lam))

    # 为边界框生成随机中心
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # 确保边界框在图像范围内
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def cumix_datas(datas, labels, alpha):
    """应用 CutMix 数据增强。"""
    lam = np.random.beta(alpha, alpha)
    rand_idx = torch.randperm(datas.size(0))  # 随机打乱索引
    bbx1, bby1, bbx2, bby2 = rand_bbox(datas.size(), lam)

    shuffle_datas = datas[rand_idx]  # 打乱数据
    # 正确引用打乱的标签
    shuffle_labels = labels[rand_idx]  # 打乱标签

    # 混合图像
    datas[:, :, bbx1:bbx2, bby1:bby2] = shuffle_datas[:, :, bbx1:bbx2, bby1:bby2]
    
    # 计算用于损失计算的调整 lambda
    lam = 1 - (bbx2 - bbx1) * (bby2 - bby1) / (datas.size(2) * datas.size(3))

    labels_a = labels
    labels_b = shuffle_labels
    return datas, labels_a, labels_b, lam

参考文章

【论文阅读笔记】CutMix:数据增强
CutMix&Mixup详解与代码实战
高阶数据增强:Cutmix 原理讲解&零基础程序实现

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Bosenya12

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值