Abstract
训练集类别机器不均衡,而测试的标准又要求对每一个类别都有很好的泛化能力,甚至,更关心少数类的表现,是 long tail 问题的本质矛盾。这篇文章提出了两个方法:1)label-distribution-aware margin(LDAM),最小化边缘泛化边界。2)一种简单但是有效的训练方式,先让模型学习初始的特征表示(initial stage),再进行re-weighting或者re-sampling。
Introduction
这几天有同学问我到底什么是long tail,我就说就是训练集不平衡但是测试集是平衡了,搞得别人一头雾水:这玩意有啥做的意义,咱们平时用的数据集不都是平衡的么?你还人为的造出来不平衡的训练集,没事找事...
啊这...
其实这个问题可以结合一下研究现状来回答,可以看到主要的贡献都是FAIR,SenseTime或者Google的,主要原因还是产业上会遇到这个问题。在实际收集数据的时候,人们不可能为了平衡数据集主动放弃大量的常见类,而只标注稀有类,这是不经济的。又或者,其实研究者对某些类的分类准确率(或者其他性能指标)不感兴趣,他们只关心特定的几类准确率。
举个例子,病灶检测。如果一个数据集中,95%是正常区域(多数类,head),剩下5%是5种不同的病灶,那么模型只预测为正常就可以达到95%的准确率,然而这是无用的,我们关心的是这5种病灶(稀有类,tail)的检出率。所以,此时tail类的准确率能否得到提升就是模型性能的关键。
当然,long tail和异常检测还是有不同的,以上例子只是为了说明long tail问题的一个意义所在。因为这篇文章是2019年CVPR的(现在CVPR2021都开始冲了,听说1w多篇,吐了),所以简介就不怎么说了,可以参考之前的文章。
Implementing this general idea requires a data-dependent or label-dependent regularizer while regularization depends only on the weight matrices
文章从正则化的角度考虑,认为l2正则不太行,提出了一个新的不但依赖于数据而且还和其label有关的正则方式。
Encouraging a large margin can be viewed as regularization, as standard generalization error bounds depend on the inverse of the minimum margin among all the examples.
核心思想是让少数类之间的margin尽可能地大。
我们以二分类为例,讲一讲margin对分类结果的直观影响:
假设一个二分类问题,一类是多数一类是少数,一个良好的分类器是找到一个分类边界,使得数据的分布被划分到两个域,从而实现完美的分类表现。一般的,margin定义为第i类中每个数据点到达分类边界的最小值。
一个不均衡的二分类例子
以上图为例,蓝点是多数类,绿点是少数类,两条线虚线分别是两个类中到边界的最小数据点,点虚线是理想的分类边界,位于两条线虚线的中间。对于一个分类器来说,位于分类边界(点虚线)左边的都是多数类,右边的都是少数类,两个类的margin都是相等的。
但是在不均衡分类问题中,模型对少数类的特征表示往往学习不够,容易过拟合从而导致少数类上很差的表现。所以考虑对少数类放低要求,从margin的角度来说,就是给予少数类更大的margin,从上图可以看出,实际的分类边界向多数类的线虚线偏移,从而对减轻少数类分类的难度。此时,两个类的margin不相等了,图中以 表示。
那么偏移多少合适呢?直观上想,偏移量应该是和每一类样本数量有关的一个值,具体的推导可以看下面的理论部分。
Related Work
- Re-sampling
- Re-weighting
- Margin loss:
Hinge loss 通常用于获得“margin”的分类器,尤其是在SVM方法中。 最近,已经提出了大边距Softmax,Angular Softmax和加性余量Softmax,以最大程度地减少类内变化的预测和通过合并角度余量的概念来扩大类间余量。 与这些论文中与类别无关的margin相反,此文的方法尽力为少数群体提供更大的margin。
- Label shift in domain adaptation
不均衡学习有时候可以看做迁移学习或着Domain adaptation中的label shift问题。在label shift中,最大的困难在于检测和估计标签的偏移,并且在这之后,应用重新加权或重新采样恢复。在long tail问题中,标签偏移是已知的,那么能否基于此做出更好的re-weight或者re-sample优化?
- Meta-learning
这个我的理解就是把一些超参(比如采样分布)也当作参数,用某些特定设计的loss在特定的epoch进行一定的更新。
Main Approach
假设一个模型为f,其在均衡数据集上的损失一般可以表示为:
其中,l和y是指输入x的预测标签和真实标签,这个公式的含义就是真实标签y的logit小于其他标签的logit的概率。
在一个样本(x,y)中,其真实标签 y=j ,则定义margin,:
这样,综合考虑数据集中所有y=j的样本(记作 ),定义每一类的margin:
上述假设所有训练样本分类完全正确,这在训练集是可以做到的。至此,定义和之前的研究完全一致。
理论推导先不写,直接上结论:
最佳的trade off:
其中nj是对应类别j的数目,C是一个超参常数。
基于hinge loss,新的loss为:
其中:
对应到softmax:
可以看到不同的label对应的margin是不一样的,和他们类的数量有关,但是值得注意的是,只有y_true减去了margin。
这个方法虽然改了loss,但是loss还是可以做re-weight的。
Deferred Re-balancing Optimization Schedule
如图,从T0之后的epoch都对loss按照每一类的频率做reweight,easy,故略。
Experiment
19年的实验结果就不看了吧,不过这个实验里好几个数据都持续出现在2020年的研究工作中。
重点关注一下Cifar100(Res32&Imbalance Factor=100)的结果。
LDAM的代码
class LDAMLoss(nn.Module):
def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
super(LDAMLoss, self).__init__()
m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) # nj的四次开方
m_list = m_list * (max_m / np.max(m_list)) # 常系数 C
m_list = torch.cuda.FloatTensor(m_list) # 转成 tensor
self.m_list = m_list
assert s > 0
self.s = s # 这个参数的作用论文里提过么?
self.weight = weight # 和频率相关的 re-weight
def forward(self, x, target):
index = torch.zeros_like(x, dtype=torch.uint8) # 和 x 维度一致全 0 的tensor
index.scatter_(1, target.data.view(-1, 1), 1) # dim idx input
index_float = index.type(torch.cuda.FloatTensor) # 转 tensor
''' 以上的idx指示的应该是一个batch的y_true '''
batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0,1))
batch_m = batch_m.view((-1, 1))
x_m = x - batch_m # y 的 logit 减去 margin
output = torch.where(index, x_m, x) # 按照修改位置合并
return F.cross_entropy(self.s*output, target, weight=self.weight)
关于参数cls_num_list是什么,如何得到?
在train.py文件中,出现了cls_num_list,是从训练集中获取的
cls_num_list = train_dataset.get_cls_num_list()
print('cls num list:')
print(cls_num_list)
args.cls_num_list = cls_num_list
论文中有使用两个自制的不平衡数据集cifar-10和cifar-100,通过函数IMBALANCECIFAR获得的
if args.dataset == 'cifar10':
train_dataset = IMBALANCECIFAR10(root='./data', imb_type=args.imb_type, imb_factor=args.imb_factor, rand_number=args.rand_number, train=True, download=True, transform=transform_train)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_val)
elif args.dataset == 'cifar100':
train_dataset = IMBALANCECIFAR100(root='./data', imb_type=args.imb_type, imb_factor=args.imb_factor, rand_number=args.rand_number, train=True, download=True, transform=transform_train)
# train_dataset
val_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_val)
else:
warnings.warn('Dataset is not listed')
return
作者通过INBALANCE函数来获取不平衡数据集
def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
img_max = len(self.data) / cls_num
img_num_per_cls = []
if imb_type == 'exp':
for cls_idx in range(cls_num):
num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
img_num_per_cls.append(int(num))
elif imb_type == 'step':
for cls_idx in range(cls_num // 2):
img_num_per_cls.append(int(img_max))
for cls_idx in range(cls_num // 2):
img_num_per_cls.append(int(img_max * imb_factor))
else:
img_num_per_cls.extend([int(img_max)] * cls_num)
return img_num_per_cls
其中,get_img_num_per_cls函数是为了获取每一类图像的样本数量。首先,函数会根据类别数量和数据集总数计算出每个类别应该有的最大图片数量。然后,根据imb_type参数的不同,使用不同的算法来计算每个类别的具体图片数量。如果imb_type是’exp’,则使用指数衰减的方式计算;如果imb_type是’step’,则使用阶梯式计算;如果imb_type既不是’exp’也不是’step’,则所有类别使用相同数量的图片。
def gen_imbalanced_data(self, img_num_per_cls):
new_data = []
new_targets = []
targets_np = np.array(self.targets, dtype=np.int64)
classes = np.unique(targets_np)
# np.random.shuffle(classes)
self.num_per_cls_dict = dict()
for the_class, the_img_num in zip(classes, img_num_per_cls):
self.num_per_cls_dict[the_class] = the_img_num
idx = np.where(targets_np == the_class)[0]
np.random.shuffle(idx)
selec_idx = idx[:the_img_num]
new_data.append(self.data[selec_idx, ...])
new_targets.extend([the_class, ] * the_img_num)
new_data = np.vstack(new_data)
self.data = new_data
self.targets = new_targets
gen_imbalanced_data函数是一个用于生成不平衡数据集的方法。它接受一个参数img_num_per_cls,该参数是一个列表,指定每个类别需要生成的图像数量。该方法首先将目标标签转换成numpy数组,并提取出所有不重复的类别。接着,对每一个类别,根据img_num_per_cls中指定的数量随机选择对应的样本,并将这些样本添加到新的数据集中。最后,将新的数据集和标签更新到原始数据集中。
def get_cls_num_list(self):
cls_num_list = []
for i in range(self.cls_num):
cls_num_list.append(self.num_per_cls_dict[i])
return cls_num_list
get_cls_num_list函数通过遍历获取每一类图像的数量。在函数内部,它通过循环遍历self.cls_num,并将self.num_per_cls_dict中每个类别的数量添加到cls_num_list中。
参考资料:
知乎:LDAM loss
论文阅读Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss