引言
文本数据增强技术在小样本分类任务上,有助于模型效果的提升。已有的数据增强技术如EDA、Conditional BERT采用的是局部替换的方式,在预训练语言模型效果显著提升的背景下,作者试图采用GPT2,以文本生成的方式合成新样本,而非局部替换的方式,探索提升模型效果的方法。
LAMBADA
LAMBADA(Language Model Based Data Augmentation)是一种基于语言模型的数据增强方法,具体详情如下:
输入部分需要“训练数据集合,分类算法、预训练语言模型、每个类别需合成的样本量”四个部分。“每个类别需合成的样本量”可以看作超参数。
第一步,使用分类算法在训练集合上训练得到一个基线分类器;理论上任意分类器均是可行的。
第二步使用预训练语言模型在训练集上fine-tune,使预训练语言模型学习训练数据的分布,该步骤是最非常重要的一步。论文中采用GPT2在训练数据上进行语言模型任务,为了保证语言模型在预测词时,考虑到语序的标签信息,训练语言模型时,输入采用如下形式:
y
1
S
E
P
x
1
E
O
S
y
2
S
E
P
x
2
E
O
S
.
.
.
.
.
.
y
n
S
E
P
x
n
E
O
S
y_1 SEP x_1 EOS y_2 SEP x_2 EOS...... y_n SEP x_n EOS
y1SEPx1EOSy2SEPx2EOS......ynSEPxnEOS
第三步用fine-tune后的语言模型生成合成的数据;为了使语言模型生成的文本序列包含特定的标签信息,在语言模型推断时,输入以 y i S E P y_iSEP yiSEP开始,直至 E O S EOS EOS结束。这种方式的优势就是直接使用了分类标签的语义信息。
第四步,使用基线分类器过滤合成的数据集,得到高质量的合成数据集;筛选条件主要参考三个因素:合成样本的真实标签、基线分类器的预测标签及预测概率,真实标签对于的样本数量,筛选出置合适的样本,作为最终的合成数据集。
第五步,将原始训练集与合成的数据作为新训练集,在新训练集上重新训练,得到最终的分类器。第一步至第五步这个过程可以不断重复,但重复该过程可能会导致数据漂移(data drifting)问题—— 合成的样本具有主导影响力。论文没有给出对应的实验结果。
实验结果
采用LAMBADA比不采用,在三种分类器上均有提升,并且在样本量较小时,效果更明显。
表明LAMBADA方法在实验数据集上,比EDA、CBERT效果更好。提升的原因可能是因为文本生成方式的多样性比局部替换的多样性更好。