文章来着第三十六届AAAI人工智能会议(AAAI-22)
概述:
主动学习(AL)的核心是应该选择哪些数据进行注释。现有的工作试图选择高度不确定或信息丰富的数据进行注释。然而,尚不清楚所选数据如何影响AL中使用的任务模型的测试性能。在本研究中,我们通过理论证明,选择梯度范数较高的未标记数据会导致测试损失的上界更低,从而获得更好的测试性能,从而探讨了这种影响。但由于缺乏标签信息,直接计算无标签数据的梯度范数是不可行的。针对这一挑战,我们提出了两种方案,即期望梯度范数和熵梯度范数。前者通过构造一个期望经验损失来计算梯度范数,后者则通过熵构造一个无监督损失。此外,我们将这两种方案集成在一个通用的人工智能框架中。我们在经典的图像分类和语义分割任务上评价了我们的方法。
介绍:
本文主要贡献有三点:首先,我们从根本上阐述了主动学习中的测试性能,发现其主要影响因素是梯度范数,梯度范数可以有效地指导无标记数据的选择。其次,我们提出了两种计算无标记数据的梯度范数的方案,而无需使用地真标记。第三,我们证明了所提出的方法在经典计算机视觉挑战和计算生物学的一个领域任务上取得了优越的性能。
选择的未标记数据如何直接影响测试性能:
根据(Koh and Liang 2017),我们知道,给定一个模型fθ,例如神经网络,从它的训练集中删除一个样本x,将大致影响测试样本xj处的损失:
其中n表示现有训练样本的数量,fθ(·)表示产生模型fθ 的全连接层输出,是所有训练样本的平均Hessian值。对于每一个训练样本,我们计算如果去掉该样本会对模型造成的影响程度,对于总影响可以根据以下公式得出:
这里我们假定样本选择发生在AL循环c次时,将其投入训练发生在C+1次循环中(即)。
可以看出,即使(
,
)可能为负值,即样本x对于测试样本xj产生不利影响,但整体来说
必定为正,直观上将,当加入一个训练样本,模型将会根据其产生的损失调整各项参数以降低损失大小,反之若去掉一个样本必然导致损失增加。
在AL中,假定第C+1次循环时测试损失为,若将训练样本移出已标记样本池,即不再参与C+1次循环,那么所造成的影响可以表示为:
我们发现,该等式直接计算是很困难的,但由于是正定的矩阵,应用Frobenius范数,我们可以推导出:
对于此式,是定值,因此
的上界取决于
,然而,直接计算该式是困难的,因为数据选择阶段C次循环中,
还没有出现,并且x的确定可信标签还没有标注。
对于没有出现的问题,我们采用
来近似替代它。
由式5可知,在循环c+1中,去除较高||∇θL(T c(x))||的训练样本x将导致L0c+1检验的较高上界。因此,在循环c + 1的标记训练池中应该保留这样的x。相反,从循环c的角度来看,在数据选择时,应选择更多||△θL(T c(x))||较高的未标记样本进行标注,并加入标记池中训练T c+1。因此,我们认为在AL中应该选择梯度范数较高的未标记数据进行注释。
因为由于缺乏x的标签信息,计算经验损失L是不可行的。为了解决这一挑战,我们提出了计算||∇θL(T c(x))||的两种方案,分别是期望梯度模和熵梯度模 。
期望梯度计算:假设在一个给定的未标记的池中有N个类,我们用yi表示第i个类的标签。注意yi是x的一个候选标签,而不是x的基本事实。对于每个样本x,它的期望损失可以通过下式:
其中P (yi|x)是使用softmax / 得到的后验值,Li是假设第i个候选标签为x的地面真相标签时的经验损失,这样可以不再需要x的确切标签。该方案可以很容易地用于分类问题,因为后验P是每个单独数据样本的单个向量。然而,对于其他问题,如语义分割(即像素级分类),该方案并不是一个理想的解决方案。这是因为在这个方案中,我们需要考虑所有单个像素的所有可能的标签,这导致了大量的可能性,这在实践中是难以处理的。为了解决这一挑战,我们在下一节提出了另一种方案来计算无标记数据的梯度范数。
熵梯度计算:在该方案中,我们使用输出熵来计算梯度范数。具体地说,我们使用网络的软最大输出的可微熵作为损失函数。由于计算熵不需要标签,因此该方案更适用于假设标签不可行的复杂任务,如语义分割任务。根据Eq. 6中P (yi|x)和N的定义,每个样本x的熵损失定义为:
主动学习框架
实验方案
为了评估所提出的方法,我们对AL设置进行了大量的实验。为了公平的比较,我们按照建议的设置和实践重新生成基线方法,例如预处理输入的方式。在除ImageNet外的所有实验中,我们运行了7个周期的AL方法,对应7个不同的注释预算(即。
从10%到40%,增量5%)。在第一个周期中,我们从未标记的池中随机选择10%的数据,并使用所选数据作为所有比较方法的初始训练数据。然后,在后续的每个周期中,我们使用特定的AL方法选择未标记的数据,并使用更新的标记池重新训练任务模型。
注意术语Cycle不同于Epoch,因为它只对应于AL中的注释预算。所有报告的结果都是3次运行的平均值。对于ImageNet,我们用5个AL循环(即注释预算从10%到30%不等)进行实验,结果在2次运行中平均,这对于非常大规模的数据集是可以接受的。
我们将我们的方法与最先进的AL基线进行比较,包括随机选择,GCN(Sequen-
tial Graph Convolutional Network for Active Learning),sraal(State-Relabeling Adversarial Active Learning),llal(Learning Loss for Active Learning),Core-Set ,mc-dropout(Dropout as A Bayesian Approximation)我们分别用exp-gn (expected-gradnorm)和entgn (entropy-gradnorm)来表示我们的方法。
实验结果及分析:(这里我们只介绍图片分类任务,语义分割不再赘述)
数据集选择:在我们的实验中,我们利用了五个广泛使用的图像分类数据集,分别是Cifar10 (Krizhevsky, Hinton et al 2009)、Cifar100 (Krizhevsky, Hinton et al 2009)、SVHN (Netzer et al 2011)、Caltech101 (Fei-Fei, Fergus,和Perona 2006)和ImageNet (Deng et al 2009)。两个Cifar数据集包括50000个训练样本和10000个测试样本,分别分布在10个和100个类中。SVHN还包含10个类,而它包含的样本比两个Cifar数据集更多,即73257个用于训练,26032个用于测试。为了与其他方法进行公平的比较,我们在SVHN中不使用额外的训练数据。Caltech101有101类更大的图像(例如300 × 200像素),这些图像不均匀地分布在类中。有些类有多达800个样本,而另一些类只包含40个样本。ImageNet是一个大型数据集,包括1000个类中约128万个训练样本。我们遵循常见的实践来报告由50000个样本组成的验证集上的模型性能。
模型使用:
我们使用ResNet-18 为除ImageNet外的所有实验的任务模型,具体来说,对于两个Cifar和SVHN数据集,由于输入维度的兼容性,我们利用了ResNet18的定制版本(GitHub - kuangliu/pytorch-cifar: 95.47% on CIFAR10 with PyTorch)。对于Caltech101,我们使用了原始的ResNet-18。为了执行公平的比较,我们对所有比较方法使用相同的任务模型。例如,(Sinha, Ebrahimi,和Darrell 2019)最初使用VGG Net作为任务模型,我们用ResNet-18替换它以保持一致性。此外,为了验证我们的方法与体系结构无关,我们使用VGG-16 (Simonyan和Zisserman 2015)作为ImageNet任务中所有方法的任务模型。
结果和分析:
如图1和图2(左)所示,我们的方法在所有数据集上都优于基线。首先,对于每个注释预算,我们的方法比其他方法获得更高的精度。这是AL方法所需的属性,因为注释预算在实际场景中可能有所不同。其次,我们的方法需要更少的标记样本来获得更好的性能。例如,在Cifar10上,我们的方法(ent-gn)在12.5K标记样本下产生了91.77%的准确性,而sraal需要多2.5K以上的样本,vaal需要多5K以上的样本来实现类似的性能。第三,在SVHN和Caltech101上的优异性能表明,我们的方法可以很好地处理不平衡数据集。第四,在ImageNet上的优越性能证明了我们的方法对于非常大规模的数据集是有效的。最后,我们观察到在Cifar10上我们的方法(ent-gn)实现了一个94.39%的精度,而用完整数据集从头开始训练相同的任务模型只能得到93.16%的结果。这一有趣的发现与(Koh and Liang 2017)中的观察一致,即一些训练数据对神经网络学习是有害的。
Koh, P . W.; and Liang, P . 2017. Understanding Black-box
Predictions via Influence Functions. In Proceedings of the
International Conference on Machine Learning, 1885–1894.
Koh, P. W.; and Liang, P. 2017. Understanding Black-box Predictions via Influence Functions. In Proceedings of the International Conference on Machine Learning, 1885–1894.