mmcls ClassBalancedDataset explained

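This class balances the number of samples across categories by oversampling images of rare categories.

Usage

Wrap the original dataset in one more layer that handles the sampling: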

# Original config of dataset Dataset_A
data = dict(
    train = dict(
        # original config of dataset Dataset_A
        type='Dataset_A',
        ...
        pipeline=train_pipeline
    ),
    ...
)
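# Config with the sampling wrapper added
data = dict(
    train = dict(
        type='ClassBalancedDataset',
        oversample_thr=1e-3, # categories whose sample fraction is below this threshold are oversampled
        dataset=dict(  # original config of dataset Dataset_A, unchanged
            type='Dataset_A',
            ...
            pipeline=train_pipeline
        )
    ),
    ...
)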

Balancing logic


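  1. Count the samples of each category and compute the category frequency $f(c)$ = (number of images containing category $c$) / (total number of images in the dataset). For single-label data this is simply the fraction of samples that belong to category $c$.
  2. Compute the category-level repeat factor $r(c) = \max(1, \sqrt{\frac{t}{f(c)}})$, where $t$ is the threshold oversample_thr. Because $\sqrt{t / f(c)} > 1$ exactly when $f(c) < t$, a category is oversampled only when its frequency is below the threshold; otherwise $r(c) = 1$.
  3. Compute the image-level repeat factor $r(I) = \max_{c \in L(I)} r(c)$, where $L(I)$ is the label set of image $I$. Taking the maximum handles multi-label images: an image uses the largest category-level repeat factor among its labels. In the wrapped dataset each image then appears math.ceil(r(I)) times (rounded up), so every image with $r(I) > 1$ is repeated.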
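The three steps can be sketched in plain Python on a toy label list. This is a minimal illustration, not the mmcls code. The function name `repeat_factors` and the toy labels are hypothetical. As step 1 describes, the sketch counts each category at most once per image.

```python
import math
from collections import Counter


def repeat_factors(labels_per_image, thr):
    """Steps 1-3 for a list of per-image label lists (hypothetical helper)."""
    num_images = len(labels_per_image)
    # Step 1: f(c) = fraction of images that contain category c
    counts = Counter(c for labels in labels_per_image for c in set(labels))
    freq = {c: n / num_images for c, n in counts.items()}
    # Step 2: r(c) = max(1, sqrt(t / f(c)))
    cat_repeat = {c: max(1.0, math.sqrt(thr / f)) for c, f in freq.items()}
    # Step 3: r(I) = max of r(c) over the labels of image I
    img_repeat = [max(cat_repeat[c] for c in set(labels))
                  for labels in labels_per_image]
    return freq, cat_repeat, img_repeat


# 10 single-label images: category 0 appears in 9, category 1 in 1
labels = [[0]] * 9 + [[1]]
freq, cat_repeat, img_repeat = repeat_factors(labels, thr=0.4)
copies = [math.ceil(r) for r in img_repeat]  # copies per image in the wrapper
print(freq)        # {0: 0.9, 1: 0.1}
print(cat_repeat)  # {0: 1.0, 1: 2.0}
print(copies)      # [1, 1, 1, 1, 1, 1, 1, 1, 1, 2]
```

Because the wrapper rounds up with math.ceil, even an image whose $r(I)$ is only slightly above 1 appears twice. The effective oversampling can therefore be noticeably stronger than $r(I)$ itself.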

Code walkthrough

Original source: https://mmclassification.readthedocs.io/zh_CN/latest/_modules/mmcls/datasets/dataset_wrappers.html#ClassBalancedDataset

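import math
from collections import defaultdict

import numpy as np
from mmcls.datasets.builder import DATASETS


@DATASETS.register_module()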
class ClassBalancedDataset(object):
    r"""A wrapper of repeated dataset with repeat factor.

    Suitable for training on class imbalanced datasets like LVIS. Following the
    sampling strategy in `this paper`_, in each epoch, an image may appear
    multiple times based on its "repeat factor".

    .. _this paper: https://arxiv.org/pdf/1908.03195.pdf

    The repeat factor for an image is a function of the frequency of the rarest
    category labeled in that image. The "frequency of category c" in [0, 1]
    is defined by the fraction of images in the training set (without repeats)
    in which category c appears.

    The dataset needs to implement :func:`self.get_cat_ids` to support
    ClassBalancedDataset.
	
    The repeat factor is computed as follows.

    1. For each category c, compute the fraction :math:`f(c)` of images that
       contain it.
    2. For each category c, compute the category-level repeat factor

        .. math::
            r(c) = \max(1, \sqrt{\frac{t}{f(c)}})

    3. For each image I and its labels :math:`L(I)`, compute the image-level
       repeat factor

        .. math::
            r(I) = \max_{c \in L(I)} r(c)

    Args:
        dataset (:obj:`BaseDataset`): The dataset to be repeated.
        oversample_thr (float): frequency threshold below which data is
            repeated. For categories with ``f_c`` >= ``oversample_thr``, there
            is no oversampling. For categories with ``f_c`` <
            ``oversample_thr``, the degree of oversampling follows the
            square-root inverse frequency heuristic above.
    """

    def __init__(self, dataset, oversample_thr):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.CLASSES = dataset.CLASSES

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        repeat_indices = []
        for dataset_index, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        # 1. For each category c, compute the fraction of images
        #    that contain it: f(c)
        category_freq = defaultdict(int)
        num_images = len(dataset)
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        for k, v in category_freq.items():
            assert v > 0, f'category {k} does not contain any images'
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I and its labels L(I), compute the image-level
        # repeat factor:
        #    r(I) = max_{c in L(I)} r(c)
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            repeat_factor = max(
                {category_repeat[cat_id]
                 for cat_id in cat_ids})
            repeat_factors.append(repeat_factor)

        return repeat_factors

    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def __len__(self):
        return len(self.repeat_indices)

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError(
            'evaluate results on a class-balanced dataset is weird. '
            'Please inference and evaluate on the original dataset.')

    def __repr__(self):
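        """Return a summary with the total number of samples."""
        # The wrapper defines no test_mode of its own; read it from the wrapped dataset
        test_mode = getattr(self.dataset, 'test_mode', False)
        dataset_type = 'Test' if test_mode else 'Train'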
        result = (
            f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
            f'{dataset_type} dataset with total number of samples {len(self)}.'
        )
        return result
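The following stripped-down re-implementation shows what `__getitem__` and `__len__` do with `repeat_indices`. It runs without mmcls or NumPy. It is a sketch, not the library class: `ToyDataset` and `MiniClassBalancedDataset` are hypothetical names, and the `flag` and `CLASSES` handling is left out.

```python
import math
from collections import Counter, defaultdict


class ToyDataset:
    """Hypothetical single-label dataset exposing get_cat_ids like BaseDataset."""

    def __init__(self, labels):
        self.labels = list(labels)

    def get_cat_ids(self, idx):
        return [self.labels[idx]]

    def __getitem__(self, idx):
        return self.labels[idx]

    def __len__(self):
        return len(self.labels)


class MiniClassBalancedDataset:
    """Simplified version of the wrapper above (no registry, flag or CLASSES)."""

    def __init__(self, dataset, oversample_thr):
        self.dataset = dataset
        n = len(dataset)
        freq = defaultdict(int)
        for i in range(n):
            for c in set(dataset.get_cat_ids(i)):
                freq[c] += 1
        cat_repeat = {c: max(1.0, math.sqrt(oversample_thr / (v / n)))
                      for c, v in freq.items()}
        # Sample i is listed ceil(r(I)) times, as in the original __init__
        self.repeat_indices = []
        for i in range(n):
            r = max(cat_repeat[c] for c in set(dataset.get_cat_ids(i)))
            self.repeat_indices.extend([i] * math.ceil(r))

    def __getitem__(self, idx):
        # Map the expanded index back to an index of the original dataset
        return self.dataset[self.repeat_indices[idx]]

    def __len__(self):
        return len(self.repeat_indices)


base = ToyDataset([0] * 96 + [1] * 4)
wrapped = MiniClassBalancedDataset(base, oversample_thr=0.25)
print(len(base), len(wrapped))                             # 100 108
print(Counter(wrapped[i] for i in range(len(wrapped))))    # Counter({0: 96, 1: 12})
```

Here $r(1) = \sqrt{0.25 / 0.04} = 2.5$, and math.ceil turns it into 3 copies per rare image. The rare class therefore grows from 4% of the samples to 12/108, about 11%, and the epoch becomes longer.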

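Note: the self.get_cat_ids method that the docstring requires simply returns the category ids of a sample. If your dataset inherits from BaseDataset you don't need to do anything, because BaseDataset already implements it: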

# Method of the BaseDataset class (excerpt; List comes from the typing module)
    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category id by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image category of specified index.
        """

        return [int(self.data_infos[idx]['gt_label'])]
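A dataset that does not inherit from BaseDataset has to provide get_cat_ids itself. The sketch below assumes a hypothetical multi-label dataset whose data_infos entries store a multi-hot 'gt_label' list. The class name, class names and data format are illustrative only and are not an mmcls API. The wrapper above also reads dataset.CLASSES and indexes the dataset, so the sketch provides those too.

```python
from typing import List


class MyMultiLabelDataset:
    """Hypothetical dataset that does not inherit from BaseDataset.

    Each data_infos entry is assumed to hold a multi-hot 'gt_label' list.
    """

    CLASSES = ('person', 'car', 'dog')  # hypothetical class names

    def __init__(self, data_infos):
        self.data_infos = data_infos

    def __len__(self):
        return len(self.data_infos)

    def __getitem__(self, idx):
        return self.data_infos[idx]

    def get_cat_ids(self, idx: int) -> List[int]:
        # Every category whose multi-hot entry is 1 belongs to L(I)
        gt_label = self.data_infos[idx]['gt_label']
        return [cat_id for cat_id, v in enumerate(gt_label) if v == 1]


ds = MyMultiLabelDataset([
    {'gt_label': [1, 0, 0]},
    {'gt_label': [1, 0, 1]},
    {'gt_label': [0, 1, 0]},
])
print([ds.get_cat_ids(i) for i in range(len(ds))])  # [[0], [0, 2], [1]]
```

get_cat_ids should return at least one id per image. In _get_repeat_factors, an image with no category ids makes max() run over an empty set, which raises ValueError.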