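This class (ClassBalancedDataset) balances the number of samples across categories by oversampling the rare ones.
Usage
Wrap the original dataset in an extra sampling layer: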
# Config of the original dataset Dataset_A
data = dict(
    train = dict(
        # Original config of Dataset_A
        type='Dataset_A',
        ...
        pipeline=train_pipeline
    )
    ...
)
# Config with the sampling wrapper added
data = dict(
    train = dict(
        type='ClassBalancedDataset',
        oversample_thr=1e-3,  # categories whose sample fraction is below this threshold are oversampled
        dataset=dict(  # original config of Dataset_A
            type='Dataset_A',
            ...
            pipeline=train_pipeline
        )
    )
    ...
)
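The same wrapping can be done programmatically. The following is a minimal sketch, assuming the config is held as plain Python dicts. `wrap_class_balanced` is a hypothetical helper name (not part of MMClassification), and the `data_prefix` and `pipeline` values are placeholders.

```python
import copy


def wrap_class_balanced(train_cfg, oversample_thr=1e-3):
    """Return a new train config that wraps `train_cfg` in ClassBalancedDataset.

    Hypothetical helper for illustration; it only rearranges plain dicts.
    """
    return dict(
        type='ClassBalancedDataset',
        oversample_thr=oversample_thr,
        # Deep-copy so later edits to the wrapper do not touch the original.
        dataset=copy.deepcopy(train_cfg),
    )


# Placeholder config for the original dataset (values are illustrative).
original_train = dict(type='Dataset_A', data_prefix='data/train', pipeline=[])
data = dict(train=wrap_class_balanced(original_train, oversample_thr=1e-3))
```

Only the training split is wrapped here; the validation and test splits keep the original dataset config.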
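Balancing logic
- Count the samples in each category and compute each category's frequency: $f(c)$ = number of samples in category $c$ / total number of samples in the dataset.
- Compute the category-level repeat factor $r(c) = \max(1, \sqrt{t / f(c)})$, where $t$ is the threshold (oversample_thr). $r(c)$ exceeds 1 exactly when $f(c) < t$, so a category is oversampled only when its frequency falls below the threshold.
- Compute the image-level repeat factor $r(I) = \max_{c \in L(I)} r(c)$. If an image's factor is greater than 1, the image is repeated math.ceil(repeat_factor) times (rounded up). The $\max_{c \in L(I)}$ covers the multi-label case: the image's repeat factor is the largest category-level repeat factor among the categories it contains.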
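The three steps above can be sketched on a toy dataset. This is a standalone illustration, not the library code: `compute_repeat_factors` is a hypothetical function name, the label lists are made up, and the threshold 0.5 is chosen only so the effect is visible on ten images.

```python
import math
from collections import defaultdict


def compute_repeat_factors(labels_per_image, thr):
    """Return (category_freq, category_repeat, image_repeat) for a toy dataset.

    labels_per_image: list of label lists, one list per image.
    thr: the oversampling threshold t.
    """
    num_images = len(labels_per_image)

    # Step 1: f(c) = fraction of images that contain category c.
    counts = defaultdict(int)
    for labels in labels_per_image:
        for c in set(labels):
            counts[c] += 1
    category_freq = {c: n / num_images for c, n in counts.items()}

    # Step 2: r(c) = max(1, sqrt(t / f(c))).
    category_repeat = {
        c: max(1.0, math.sqrt(thr / f)) for c, f in category_freq.items()
    }

    # Step 3: r(I) = max over the image's labels of r(c).
    image_repeat = [
        max(category_repeat[c] for c in set(labels))
        for labels in labels_per_image
    ]
    return category_freq, category_repeat, image_repeat


# Toy data (made up): 8 images of class 0, 1 of class 1, 1 with both labels.
toy_labels = [[0]] * 8 + [[1]] + [[0, 1]]
freq, cat_rep, img_rep = compute_repeat_factors(toy_labels, thr=0.5)
copies = [math.ceil(r) for r in img_rep]
```

Here class 0 has frequency 0.9 and class 1 has frequency 0.2. With the threshold 0.5, only class 1 gets a factor above 1 (sqrt(2.5), about 1.58), so the two images containing class 1 each appear twice after rounding up, and the dataset grows from 10 to 12 samples.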
Code walkthrough
Original source: https://mmclassification.readthedocs.io/zh_CN/latest/_modules/mmcls/datasets/dataset_wrappers.html#ClassBalancedDataset
import math
from collections import defaultdict

import numpy as np

# DATASETS is the dataset registry provided by mmcls.


@DATASETS.register_module()
class ClassBalancedDataset(object):
    r"""A wrapper of repeated dataset with repeat factor.

    Suitable for training on class imbalanced datasets like LVIS. Following the
    sampling strategy in `this paper`_, in each epoch, an image may appear
    multiple times based on its "repeat factor".

    .. _this paper: https://arxiv.org/pdf/1908.03195.pdf

    The repeat factor for an image is a function of the frequency the rarest
    category labeled in that image. The "frequency of category c" in [0, 1]
    is defined by the fraction of images in the training set (without repeats)
    in which category c appears.

    The dataset needs to implement :func:`self.get_cat_ids` to support
    ClassBalancedDataset.

    The repeat factor is computed as followed.

    1. For each category c, compute the fraction :math:`f(c)` of images that
       contain it.
    2. For each category c, compute the category-level repeat factor

        .. math::
            r(c) = \max(1, \sqrt{\frac{t}{f(c)}})

    3. For each image I and its labels :math:`L(I)`, compute the image-level
       repeat factor

        .. math::
            r(I) = \max_{c \in L(I)} r(c)

    Args:
        dataset (:obj:`BaseDataset`): The dataset to be repeated.
        oversample_thr (float): frequency threshold below which data is
            repeated. For categories with ``f_c`` >= ``oversample_thr``, there
            is no oversampling. For categories with ``f_c`` <
            ``oversample_thr``, the degree of oversampling following the
            square-root inverse frequency heuristic above.
    """

    def __init__(self, dataset, oversample_thr):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.CLASSES = dataset.CLASSES

        # List each original index ceil(r(I)) times.
        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        repeat_indices = []
        for dataset_index, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        # Repeat the per-sample flags the same way, if the dataset has them.
        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        # 1. For each category c, compute the fraction # of images
        #   that contain it: f(c)
        category_freq = defaultdict(int)
        num_images = len(dataset)
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        for k, v in category_freq.items():
            assert v > 0, f'caterogy {k} does not contain any images'
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I and its labels L(I), compute the image-level
        #    repeat factor:
        #    r(I) = max_{c in L(I)} r(c)
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            repeat_factor = max(
                {category_repeat[cat_id]
                 for cat_id in cat_ids})
            repeat_factors.append(repeat_factor)

        return repeat_factors

    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def __len__(self):
        return len(self.repeat_indices)

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError(
            'evaluate results on a class-balanced dataset is weird. '
            'Please inference and evaluate on the original dataset.')

    def __repr__(self):
        """Print the number of instance number."""
        dataset_type = 'Test' if self.test_mode else 'Train'
        result = (
            f'\n{self.__class__.__name__} ({self.dataset.__class__.__name__}) '
            f'{dataset_type} dataset with total number of samples {len(self)}.'
        )
        return result
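The index expansion in __init__ and the lookup in __getitem__ can be traced in isolation. A minimal sketch, assuming made-up image-level repeat factors and a plain list standing in for the wrapped dataset; `original` and `get_item` are hypothetical names for illustration.

```python
import math

# Hypothetical image-level repeat factors for a 4-image dataset.
repeat_factors = [1.0, 1.58, 1.0, 3.0]

# Same expansion as in __init__: image i is listed ceil(r(I)) times.
repeat_indices = []
for dataset_index, repeat_factor in enumerate(repeat_factors):
    repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))

# Stand-in for the wrapped dataset; real samples would be loaded images.
original = ['img0', 'img1', 'img2', 'img3']


def get_item(idx):
    """Mimic __getitem__: look up the original index, then fetch the sample."""
    return original[repeat_indices[idx]]
```

The expanded index list is [0, 1, 1, 2, 3, 3, 3], so the wrapper reports a length of 7 instead of 4. Because of the rounding up, a fractional factor such as 1.58 yields two full copies, not a probabilistic repeat.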
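Note: the docstring says the dataset must implement self.get_cat_ids. This method simply returns the category ids of a sample. If your dataset inherits from BaseDataset, it is already implemented and needs no extra work.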
# Method of the BaseDataset class
def get_cat_ids(self, idx: int) -> List[int]:
    """Get category id by index.

    Args:
        idx (int): Index of data.

    Returns:
        cat_ids (List[int]): Image category of specified index.
    """

    return [int(self.data_infos[idx]['gt_label'])]
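A dataset that does not inherit from BaseDataset has to provide get_cat_ids itself. Below is a minimal sketch for the multi-label case, assuming each sample stores a multi-hot list under `gt_label`. `ToyMultiLabelDataset`, its class names, and its data layout are hypothetical; the only point is the return contract, a list of integer category ids per sample.

```python
from typing import List


class ToyMultiLabelDataset:
    """Hypothetical dataset whose samples carry multi-hot labels."""

    CLASSES = ('cat', 'dog', 'bird')

    def __init__(self, data_infos):
        # Each entry holds a multi-hot list, e.g. [1, 0, 1] = cat and bird.
        self.data_infos = data_infos

    def __len__(self):
        return len(self.data_infos)

    def get_cat_ids(self, idx: int) -> List[int]:
        """Return the indices of all categories present in sample `idx`."""
        multi_hot = self.data_infos[idx]['gt_label']
        return [i for i, present in enumerate(multi_hot) if present == 1]


toy = ToyMultiLabelDataset([
    {'gt_label': [1, 0, 0]},
    {'gt_label': [1, 0, 1]},
])
```

Returning every category id of a sample is what lets the image-level repeat factor take the maximum over all of the sample's labels.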